Skip to content

Commit

Permalink
[OpenMP][OMPIRBuilder] Migrate MapCombinedInfoTy from Clang to OpenMP…
Browse files Browse the repository at this point in the history
…IRBuilder

This patch migrates the MapCombinedInfoTy from Clang codegen to OpenMPIRBuilder.

Differential Revision: https://reviews.llvm.org/D149666
  • Loading branch information
TIFitis committed May 4, 2023
1 parent 147a561 commit 35309db
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 60 deletions.
95 changes: 35 additions & 60 deletions clang/lib/CodeGen/CGOpenMPRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6831,67 +6831,30 @@ class MappableExprsHandler {
const Expr *getMapExpr() const { return MapExpr; }
};

/// Class that associates information with a base pointer to be passed to the
/// runtime library.
class BasePointerInfo {
/// The base pointer.
llvm::Value *Ptr = nullptr;
/// The base declaration that refers to this device pointer, or null if
/// there is none.
const ValueDecl *DevPtrDecl = nullptr;

public:
BasePointerInfo(llvm::Value *Ptr, const ValueDecl *DevPtrDecl = nullptr)
: Ptr(Ptr), DevPtrDecl(DevPtrDecl) {}
llvm::Value *operator*() const { return Ptr; }
const ValueDecl *getDevicePtrDecl() const { return DevPtrDecl; }
void setDevicePtrDecl(const ValueDecl *D) { DevPtrDecl = D; }
};

using MapBaseValuesArrayTy = llvm::OpenMPIRBuilder::MapValuesArrayTy;
using MapValuesArrayTy = llvm::OpenMPIRBuilder::MapValuesArrayTy;
using MapFlagsArrayTy = llvm::OpenMPIRBuilder::MapFlagsArrayTy;
using MapDimArrayTy = llvm::OpenMPIRBuilder::MapDimArrayTy;
using MapNonContiguousArrayTy =
llvm::OpenMPIRBuilder::MapNonContiguousArrayTy;
using MapExprsArrayTy = SmallVector<MappingExprInfo, 4>;
using MapBaseValuesArrayTy = SmallVector<BasePointerInfo, 4>;
using MapValuesArrayTy = SmallVector<llvm::Value *, 4>;
using MapFlagsArrayTy = SmallVector<OpenMPOffloadMappingFlags, 4>;
using MapMappersArrayTy = SmallVector<const ValueDecl *, 4>;
using MapDimArrayTy = SmallVector<uint64_t, 4>;
using MapNonContiguousArrayTy = SmallVector<MapValuesArrayTy, 4>;
using MapValueDeclsArrayTy = SmallVector<const ValueDecl *, 4>;

/// This structure contains combined information generated for mappable
/// clauses, including base pointers, pointers, sizes, map types, user-defined
/// mappers, and non-contiguous information.
struct MapCombinedInfoTy {
struct StructNonContiguousInfo {
bool IsNonContiguous = false;
MapDimArrayTy Dims;
MapNonContiguousArrayTy Offsets;
MapNonContiguousArrayTy Counts;
MapNonContiguousArrayTy Strides;
};
struct MapCombinedInfoTy : llvm::OpenMPIRBuilder::MapInfosTy {
MapExprsArrayTy Exprs;
MapBaseValuesArrayTy BasePointers;
MapValuesArrayTy Pointers;
MapValuesArrayTy Sizes;
MapFlagsArrayTy Types;
MapMappersArrayTy Mappers;
StructNonContiguousInfo NonContigInfo;
MapValueDeclsArrayTy Mappers;
MapValueDeclsArrayTy DevicePtrDecls;

/// Append arrays in \a CurInfo.
void append(MapCombinedInfoTy &CurInfo) {
Exprs.append(CurInfo.Exprs.begin(), CurInfo.Exprs.end());
BasePointers.append(CurInfo.BasePointers.begin(),
CurInfo.BasePointers.end());
Pointers.append(CurInfo.Pointers.begin(), CurInfo.Pointers.end());
Sizes.append(CurInfo.Sizes.begin(), CurInfo.Sizes.end());
Types.append(CurInfo.Types.begin(), CurInfo.Types.end());
DevicePtrDecls.append(CurInfo.DevicePtrDecls.begin(),
CurInfo.DevicePtrDecls.end());
Mappers.append(CurInfo.Mappers.begin(), CurInfo.Mappers.end());
NonContigInfo.Dims.append(CurInfo.NonContigInfo.Dims.begin(),
CurInfo.NonContigInfo.Dims.end());
NonContigInfo.Offsets.append(CurInfo.NonContigInfo.Offsets.begin(),
CurInfo.NonContigInfo.Offsets.end());
NonContigInfo.Counts.append(CurInfo.NonContigInfo.Counts.begin(),
CurInfo.NonContigInfo.Counts.end());
NonContigInfo.Strides.append(CurInfo.NonContigInfo.Strides.begin(),
CurInfo.NonContigInfo.Strides.end());
llvm::OpenMPIRBuilder::MapInfosTy::append(CurInfo);
}
};

Expand Down Expand Up @@ -7638,6 +7601,7 @@ class MappableExprsHandler {
assert(Size && "Failed to determine structure size");
CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr);
CombinedInfo.BasePointers.push_back(BP.getPointer());
CombinedInfo.DevicePtrDecls.push_back(nullptr);
CombinedInfo.Pointers.push_back(LB.getPointer());
CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
Size, CGF.Int64Ty, /*isSigned=*/true));
Expand All @@ -7649,6 +7613,7 @@ class MappableExprsHandler {
}
CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr);
CombinedInfo.BasePointers.push_back(BP.getPointer());
CombinedInfo.DevicePtrDecls.push_back(nullptr);
CombinedInfo.Pointers.push_back(LB.getPointer());
Size = CGF.Builder.CreatePtrDiff(
CGF.Int8Ty, CGF.Builder.CreateConstGEP(HB, 1).getPointer(),
Expand All @@ -7666,6 +7631,7 @@ class MappableExprsHandler {
(Next == CE && MapType != OMPC_MAP_unknown)) {
CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr);
CombinedInfo.BasePointers.push_back(BP.getPointer());
CombinedInfo.DevicePtrDecls.push_back(nullptr);
CombinedInfo.Pointers.push_back(LB.getPointer());
CombinedInfo.Sizes.push_back(
CGF.Builder.CreateIntCast(Size, CGF.Int64Ty, /*isSigned=*/true));
Expand Down Expand Up @@ -8168,7 +8134,8 @@ class MappableExprsHandler {
[&UseDeviceDataCombinedInfo](const ValueDecl *VD, llvm::Value *Ptr,
CodeGenFunction &CGF) {
UseDeviceDataCombinedInfo.Exprs.push_back(VD);
UseDeviceDataCombinedInfo.BasePointers.emplace_back(Ptr, VD);
UseDeviceDataCombinedInfo.BasePointers.emplace_back(Ptr);
UseDeviceDataCombinedInfo.DevicePtrDecls.emplace_back(VD);
UseDeviceDataCombinedInfo.Pointers.push_back(Ptr);
UseDeviceDataCombinedInfo.Sizes.push_back(
llvm::Constant::getNullValue(CGF.Int64Ty));
Expand Down Expand Up @@ -8337,8 +8304,7 @@ class MappableExprsHandler {
assert(RelevantVD &&
"No relevant declaration related with device pointer??");

CurInfo.BasePointers[CurrentBasePointersIdx].setDevicePtrDecl(
RelevantVD);
CurInfo.DevicePtrDecls[CurrentBasePointersIdx] = RelevantVD;
CurInfo.Types[CurrentBasePointersIdx] |=
OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
}
Expand Down Expand Up @@ -8377,7 +8343,8 @@ class MappableExprsHandler {
OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF);
}
CurInfo.Exprs.push_back(L.VD);
CurInfo.BasePointers.emplace_back(BasePtr, L.VD);
CurInfo.BasePointers.emplace_back(BasePtr);
CurInfo.DevicePtrDecls.emplace_back(L.VD);
CurInfo.Pointers.push_back(Ptr);
CurInfo.Sizes.push_back(
llvm::Constant::getNullValue(this->CGF.Int64Ty));
Expand Down Expand Up @@ -8472,6 +8439,7 @@ class MappableExprsHandler {
CombinedInfo.Exprs.push_back(VD);
// Base is the base of the struct
CombinedInfo.BasePointers.push_back(PartialStruct.Base.getPointer());
CombinedInfo.DevicePtrDecls.push_back(nullptr);
// Pointer is the address of the lowest element
llvm::Value *LB = LBAddr.getPointer();
const CXXMethodDecl *MD =
Expand Down Expand Up @@ -8593,6 +8561,7 @@ class MappableExprsHandler {
VDLVal.getPointer(CGF));
CombinedInfo.Exprs.push_back(VD);
CombinedInfo.BasePointers.push_back(ThisLVal.getPointer(CGF));
CombinedInfo.DevicePtrDecls.push_back(nullptr);
CombinedInfo.Pointers.push_back(ThisLValVal.getPointer(CGF));
CombinedInfo.Sizes.push_back(
CGF.Builder.CreateIntCast(CGF.getTypeSize(CGF.getContext().VoidPtrTy),
Expand All @@ -8619,6 +8588,7 @@ class MappableExprsHandler {
VDLVal.getPointer(CGF));
CombinedInfo.Exprs.push_back(VD);
CombinedInfo.BasePointers.push_back(VarLVal.getPointer(CGF));
CombinedInfo.DevicePtrDecls.push_back(nullptr);
CombinedInfo.Pointers.push_back(VarLValVal.getPointer(CGF));
CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
CGF.getTypeSize(
Expand All @@ -8630,6 +8600,7 @@ class MappableExprsHandler {
VDLVal.getPointer(CGF));
CombinedInfo.Exprs.push_back(VD);
CombinedInfo.BasePointers.push_back(VarLVal.getPointer(CGF));
CombinedInfo.DevicePtrDecls.push_back(nullptr);
CombinedInfo.Pointers.push_back(VarRVal.getScalarVal());
CombinedInfo.Sizes.push_back(llvm::ConstantInt::get(CGF.Int64Ty, 0));
}
Expand All @@ -8654,7 +8625,7 @@ class MappableExprsHandler {
OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF |
OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
continue;
llvm::Value *BasePtr = LambdaPointers.lookup(*BasePointers[I]);
llvm::Value *BasePtr = LambdaPointers.lookup(BasePointers[I]);
assert(BasePtr && "Unable to find base lambda address.");
int TgtIdx = -1;
for (unsigned J = I; J > 0; --J) {
Expand Down Expand Up @@ -8696,7 +8667,8 @@ class MappableExprsHandler {
// pass its value.
if (VD && (DevPointersMap.count(VD) || HasDevAddrsMap.count(VD))) {
CombinedInfo.Exprs.push_back(VD);
CombinedInfo.BasePointers.emplace_back(Arg, VD);
CombinedInfo.BasePointers.emplace_back(Arg);
CombinedInfo.DevicePtrDecls.emplace_back(VD);
CombinedInfo.Pointers.push_back(Arg);
CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
CGF.getTypeSize(CGF.getContext().VoidPtrTy), CGF.Int64Ty,
Expand Down Expand Up @@ -8938,6 +8910,7 @@ class MappableExprsHandler {
if (CI.capturesThis()) {
CombinedInfo.Exprs.push_back(nullptr);
CombinedInfo.BasePointers.push_back(CV);
CombinedInfo.DevicePtrDecls.push_back(nullptr);
CombinedInfo.Pointers.push_back(CV);
const auto *PtrTy = cast<PointerType>(RI.getType().getTypePtr());
CombinedInfo.Sizes.push_back(
Expand All @@ -8950,6 +8923,7 @@ class MappableExprsHandler {
const VarDecl *VD = CI.getCapturedVar();
CombinedInfo.Exprs.push_back(VD->getCanonicalDecl());
CombinedInfo.BasePointers.push_back(CV);
CombinedInfo.DevicePtrDecls.push_back(nullptr);
CombinedInfo.Pointers.push_back(CV);
if (!RI.getType()->isAnyPointerType()) {
// We have to signal to the runtime captures passed by value that are
Expand Down Expand Up @@ -8981,6 +8955,7 @@ class MappableExprsHandler {
auto I = FirstPrivateDecls.find(VD);
CombinedInfo.Exprs.push_back(VD->getCanonicalDecl());
CombinedInfo.BasePointers.push_back(CV);
CombinedInfo.DevicePtrDecls.push_back(nullptr);
if (I != FirstPrivateDecls.end() && ElementType->isAnyPointerType()) {
Address PtrAddr = CGF.EmitLoadOfReference(CGF.MakeAddrLValue(
CV, ElementType, CGF.getContext().getDeclAlign(VD),
Expand Down Expand Up @@ -9266,7 +9241,7 @@ static void emitOffloadingArrays(
}

for (unsigned I = 0; I < Info.NumberOfPtrs; ++I) {
llvm::Value *BPVal = *CombinedInfo.BasePointers[I];
llvm::Value *BPVal = CombinedInfo.BasePointers[I];
llvm::Value *BP = CGF.Builder.CreateConstInBoundsGEP2_32(
llvm::ArrayType::get(CGM.VoidPtrTy, Info.NumberOfPtrs),
Info.RTArgs.BasePointersArray, 0, I);
Expand All @@ -9277,8 +9252,7 @@ static void emitOffloadingArrays(
CGF.Builder.CreateStore(BPVal, BPAddr);

if (Info.requiresDevicePointerInfo())
if (const ValueDecl *DevVD =
CombinedInfo.BasePointers[I].getDevicePtrDecl())
if (const ValueDecl *DevVD = CombinedInfo.DevicePtrDecls[I])
Info.CaptureDeviceAddrMap.try_emplace(DevVD, BPAddr);

llvm::Value *PVal = CombinedInfo.Pointers[I];
Expand Down Expand Up @@ -9592,7 +9566,7 @@ void CGOpenMPRuntime::emitUserDefinedMapper(const OMPDeclareMapperDecl *D,
// Fill up the runtime mapper handle for all components.
for (unsigned I = 0; I < Info.BasePointers.size(); ++I) {
llvm::Value *CurBaseArg = MapperCGF.Builder.CreateBitCast(
*Info.BasePointers[I], CGM.getTypes().ConvertTypeForMem(C.VoidPtrTy));
Info.BasePointers[I], CGM.getTypes().ConvertTypeForMem(C.VoidPtrTy));
llvm::Value *CurBeginArg = MapperCGF.Builder.CreateBitCast(
Info.Pointers[I], CGM.getTypes().ConvertTypeForMem(C.VoidPtrTy));
llvm::Value *CurSizeArg = Info.Sizes[I];
Expand Down Expand Up @@ -10028,6 +10002,7 @@ void CGOpenMPRuntime::emitTargetCall(
if (CI->capturesVariableArrayType()) {
CurInfo.Exprs.push_back(nullptr);
CurInfo.BasePointers.push_back(*CV);
CurInfo.DevicePtrDecls.push_back(nullptr);
CurInfo.Pointers.push_back(*CV);
CurInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
CGF.getTypeSize(RI->getType()), CGF.Int64Ty, /*isSigned=*/true));
Expand Down
43 changes: 43 additions & 0 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -1445,6 +1445,49 @@ class OpenMPIRBuilder {
bool separateBeginEndCalls() { return SeparateBeginEndCalls; }
};

using MapValuesArrayTy = SmallVector<Value *, 4>;
using MapFlagsArrayTy = SmallVector<omp::OpenMPOffloadMappingFlags, 4>;
using MapNamesArrayTy = SmallVector<Constant *, 4>;
using MapDimArrayTy = SmallVector<uint64_t, 4>;
using MapNonContiguousArrayTy = SmallVector<MapValuesArrayTy, 4>;

/// This structure contains combined information generated for mappable
/// clauses, including base pointers, pointers, sizes, map types, user-defined
/// mappers, and non-contiguous information.
struct MapInfosTy {
struct StructNonContiguousInfo {
bool IsNonContiguous = false;
MapDimArrayTy Dims;
MapNonContiguousArrayTy Offsets;
MapNonContiguousArrayTy Counts;
MapNonContiguousArrayTy Strides;
};
MapValuesArrayTy BasePointers;
MapValuesArrayTy Pointers;
MapValuesArrayTy Sizes;
MapFlagsArrayTy Types;
MapNamesArrayTy Names;
StructNonContiguousInfo NonContigInfo;

/// Append arrays in \a CurInfo.
void append(MapInfosTy &CurInfo) {
BasePointers.append(CurInfo.BasePointers.begin(),
CurInfo.BasePointers.end());
Pointers.append(CurInfo.Pointers.begin(), CurInfo.Pointers.end());
Sizes.append(CurInfo.Sizes.begin(), CurInfo.Sizes.end());
Types.append(CurInfo.Types.begin(), CurInfo.Types.end());
Names.append(CurInfo.Names.begin(), CurInfo.Names.end());
NonContigInfo.Dims.append(CurInfo.NonContigInfo.Dims.begin(),
CurInfo.NonContigInfo.Dims.end());
NonContigInfo.Offsets.append(CurInfo.NonContigInfo.Offsets.begin(),
CurInfo.NonContigInfo.Offsets.end());
NonContigInfo.Counts.append(CurInfo.NonContigInfo.Counts.begin(),
CurInfo.NonContigInfo.Counts.end());
NonContigInfo.Strides.append(CurInfo.NonContigInfo.Strides.begin(),
CurInfo.NonContigInfo.Strides.end());
}
};

/// Emit the arguments to be passed to the runtime library based on the
/// arrays of base pointers, pointers, sizes, map types, and mappers. If
/// ForEndCall, emit map types to be passed for the end of the region instead
Expand Down

0 comments on commit 35309db

Please sign in to comment.