@@ -1128,9 +1128,10 @@ struct DeferredStore {
1128
1128
} // namespace
1129
1129
1130
1130
// / Check whether allocations for the given operation might potentially have to
1131
- // / be done in device shared memory. That means we're compiling for a offloading
1132
- // / target, the operation is an `omp::TargetOp` or nested inside of one and that
1133
- // / target region represents a Generic (non-SPMD) kernel.
1131
+ // / be done in device shared memory. That means we're compiling for an
1132
+ // / offloading target and either the operation is neither an `omp::TargetOp`
1133
+ // / nor nested inside of one, or the enclosing target region represents a
1134
+ // / Generic (non-SPMD) kernel.
1134
1135
// /
1135
1136
// / This represents a necessary but not sufficient set of conditions to use
1136
1137
// / device shared memory in place of regular allocas. For some variables, the
@@ -1146,7 +1147,7 @@ mightAllocInDeviceSharedMemory(Operation &op,
1146
1147
if (!targetOp)
1147
1148
targetOp = op.getParentOfType <omp::TargetOp>();
1148
1149
1149
- return targetOp &&
1150
+ return ! targetOp ||
1150
1151
targetOp.getKernelExecFlags (targetOp.getInnermostCapturedOmpOp ()) ==
1151
1152
omp::TargetExecMode::generic;
1152
1153
}
@@ -1160,18 +1161,36 @@ mightAllocInDeviceSharedMemory(Operation &op,
1160
1161
// / operation that owns the specified block argument.
1161
1162
static bool mustAllocPrivateVarInDeviceSharedMemory (BlockArgument value) {
1162
1163
Operation *parentOp = value.getOwner ()->getParentOp ();
1163
- auto targetOp = dyn_cast<omp::TargetOp>(parentOp);
1164
- if (!targetOp)
1165
- targetOp = parentOp->getParentOfType <omp::TargetOp>();
1166
- assert (targetOp && " expected a parent omp.target operation" );
1167
-
1164
+ auto moduleOp = parentOp->getParentOfType <ModuleOp>();
1168
1165
for (auto *user : value.getUsers ()) {
1169
1166
if (auto parallelOp = dyn_cast<omp::ParallelOp>(user)) {
1170
1167
if (llvm::is_contained (parallelOp.getReductionVars (), value))
1171
1168
return true ;
1172
1169
} else if (auto parallelOp = user->getParentOfType <omp::ParallelOp>()) {
1173
- if (parentOp->isProperAncestor (parallelOp))
1174
- return true ;
1170
+ if (parentOp->isProperAncestor (parallelOp)) {
1171
+ // If it is used directly inside of a parallel region, skip private
1172
+ // clause uses whose privatizer has no copy region.
1173
+ bool isPrivateClauseUse = false ;
1174
+ if (auto argIface = dyn_cast<omp::BlockArgOpenMPOpInterface>(user)) {
1175
+ if (auto privateSyms = llvm::cast_or_null<ArrayAttr>(
1176
+ user->getAttr (" private_syms" ))) {
1177
+ for (auto [var, sym] :
1178
+ llvm::zip_equal (argIface.getPrivateVars (), privateSyms)) {
1179
+ if (var != value)
1180
+ continue ;
1181
+
1182
+ auto privateOp = cast<omp::PrivateClauseOp>(
1183
+ moduleOp.lookupSymbol (cast<SymbolRefAttr>(sym)));
1184
+ if (privateOp.getCopyRegion ().empty ()) {
1185
+ isPrivateClauseUse = true ;
1186
+ break ;
1187
+ }
1188
+ }
1189
+ }
1190
+ }
1191
+ if (!isPrivateClauseUse)
1192
+ return true ;
1193
+ }
1175
1194
}
1176
1195
}
1177
1196
@@ -1196,8 +1215,8 @@ allocReductionVars(T op, ArrayRef<BlockArgument> reductionArgs,
1196
1215
builder.SetInsertPoint (allocaIP.getBlock ()->getTerminator ());
1197
1216
1198
1217
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
1199
- bool useDeviceSharedMem =
1200
- isa<omp::TeamsOp>(op) && mightAllocInDeviceSharedMemory (*op, *ompBuilder);
1218
+ bool useDeviceSharedMem = isa<omp::TeamsOp>(*op) &&
1219
+ mightAllocInDeviceSharedMemory (*op, *ompBuilder);
1201
1220
1202
1221
// delay creating stores until after all allocas
1203
1222
deferredStores.reserve (op.getNumReductionVars ());
@@ -1318,8 +1337,8 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
1318
1337
return success ();
1319
1338
1320
1339
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
1321
- bool useDeviceSharedMem =
1322
- isa<omp::TeamsOp>(op) && mightAllocInDeviceSharedMemory (*op, *ompBuilder);
1340
+ bool useDeviceSharedMem = isa<omp::TeamsOp>(*op) &&
1341
+ mightAllocInDeviceSharedMemory (*op, *ompBuilder);
1323
1342
1324
1343
llvm::BasicBlock *initBlock = splitBB (builder, true , " omp.reduction.init" );
1325
1344
auto allocaIP = llvm::IRBuilderBase::InsertPoint (
@@ -1540,8 +1559,8 @@ static LogicalResult createReductionsAndCleanup(
1540
1559
reductionRegions, privateReductionVariables, moduleTranslation, builder,
1541
1560
" omp.reduction.cleanup" );
1542
1561
1543
- bool useDeviceSharedMem =
1544
- isa<omp::TeamsOp>(op) && mightAllocInDeviceSharedMemory (*op, *ompBuilder);
1562
+ bool useDeviceSharedMem = isa<omp::TeamsOp>(*op) &&
1563
+ mightAllocInDeviceSharedMemory (*op, *ompBuilder);
1545
1564
if (useDeviceSharedMem) {
1546
1565
for (auto [var, reductionDecl] :
1547
1566
llvm::zip_equal (privateReductionVariables, reductionDecls))
@@ -1721,7 +1740,7 @@ allocatePrivateVars(T op, llvm::IRBuilderBase &builder,
1721
1740
1722
1741
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
1723
1742
bool mightUseDeviceSharedMem =
1724
- isa<omp::TeamsOp, omp::DistributeOp>(*op) &&
1743
+ isa<omp::TargetOp, omp:: TeamsOp, omp::DistributeOp>(*op) &&
1725
1744
mightAllocInDeviceSharedMemory (*op, *ompBuilder);
1726
1745
unsigned int allocaAS =
1727
1746
moduleTranslation.getLLVMModule ()->getDataLayout ().getAllocaAddrSpace ();
@@ -1839,7 +1858,7 @@ cleanupPrivateVars(T op, llvm::IRBuilderBase &builder,
1839
1858
1840
1859
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
1841
1860
bool mightUseDeviceSharedMem =
1842
- isa<omp::TeamsOp, omp::DistributeOp>(*op) &&
1861
+ isa<omp::TargetOp, omp:: TeamsOp, omp::DistributeOp>(*op) &&
1843
1862
mightAllocInDeviceSharedMemory (*op, *ompBuilder);
1844
1863
for (auto [privDecl, llvmPrivVar, blockArg] :
1845
1864
llvm::zip_equal (privateVarsInfo.privatizers , privateVarsInfo.llvmVars ,
@@ -5265,42 +5284,68 @@ handleDeclareTargetMapVar(MapInfoData &mapData,
5265
5284
// a store of the kernel argument into this allocated memory which
5266
5285
// will then be loaded from, ByCopy will use the allocated memory
5267
5286
// directly.
5268
- static llvm::IRBuilderBase::InsertPoint
5269
- createDeviceArgumentAccessor ( MapInfoData &mapData, llvm::Argument &arg,
5270
- llvm::Value *input, llvm::Value *&retVal,
5271
- llvm::IRBuilderBase &builder ,
5272
- llvm::OpenMPIRBuilder &ompBuilder ,
5273
- LLVM::ModuleTranslation &moduleTranslation ,
5274
- llvm::IRBuilderBase::InsertPoint allocaIP ,
5275
- llvm::IRBuilderBase::InsertPoint codeGenIP ) {
5287
+ static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor (
5288
+ omp::TargetOp targetOp, MapInfoData &mapData, llvm::Argument &arg,
5289
+ llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder ,
5290
+ llvm::OpenMPIRBuilder &ompBuilder ,
5291
+ LLVM::ModuleTranslation &moduleTranslation ,
5292
+ llvm::IRBuilderBase::InsertPoint allocIP ,
5293
+ llvm::IRBuilderBase::InsertPoint codeGenIP ,
5294
+ llvm::ArrayRef<llvm:: IRBuilderBase::InsertPoint> deallocIPs ) {
5276
5295
assert (ompBuilder.Config .isTargetDevice () &&
5277
5296
" function only supported for target device codegen" );
5278
- builder.restoreIP (allocaIP );
5297
+ builder.restoreIP (allocIP );
5279
5298
5280
5299
omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
5281
5300
LLVM::TypeToLLVMIRTranslator typeToLLVMIRTranslator (
5282
5301
ompBuilder.M .getContext ());
5283
5302
unsigned alignmentValue = 0 ;
5303
+ BlockArgument mlirArg;
5284
5304
// Find the associated MapInfoData entry for the current input
5285
- for (size_t i = 0 ; i < mapData.MapClause .size (); ++i)
5305
+ for (size_t i = 0 ; i < mapData.MapClause .size (); ++i) {
5286
5306
if (mapData.OriginalValue [i] == input) {
5287
5307
auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause [i]);
5288
5308
capture = mapOp.getMapCaptureType ();
5289
5309
// Get information of alignment of mapped object
5290
5310
alignmentValue = typeToLLVMIRTranslator.getPreferredAlignment (
5291
5311
mapOp.getVarType (), ompBuilder.M .getDataLayout ());
5312
+ // Get the corresponding target entry block argument
5313
+ mlirArg =
5314
+ cast<omp::BlockArgOpenMPOpInterface>(*targetOp).getMapBlockArgs ()[i];
5292
5315
break ;
5293
5316
}
5317
+ }
5294
5318
5295
5319
unsigned int allocaAS = ompBuilder.M .getDataLayout ().getAllocaAddrSpace ();
5296
5320
unsigned int defaultAS =
5297
5321
ompBuilder.M .getDataLayout ().getProgramAddressSpace ();
5298
5322
5299
- // Create the alloca for the argument the current point.
5300
- llvm::Value *v = builder.CreateAlloca (arg.getType (), allocaAS);
5323
+ // Create the allocation for the argument.
5324
+ llvm::Value *v = nullptr ;
5325
+ if (mightAllocInDeviceSharedMemory (*targetOp, ompBuilder) &&
5326
+ mustAllocPrivateVarInDeviceSharedMemory (mlirArg)) {
5327
+ // Use the beginning of the codeGenIP rather than the usual allocation point
5328
+ // for shared memory allocations because otherwise these would be done prior
5329
+ // to the target initialization call. Also, the exit block (where the
5330
+ // deallocation is placed) is only executed if the initialization call
5331
+ // succeeds.
5332
+ builder.SetInsertPoint (codeGenIP.getBlock ()->getFirstInsertionPt ());
5333
+ v = ompBuilder.createOMPAllocShared (builder, arg.getType ());
5334
+
5335
+ // Create deallocations in all provided deallocation points and then restore
5336
+ // the insertion point to right after the new allocations.
5337
+ llvm::IRBuilderBase::InsertPointGuard guard (builder);
5338
+ for (auto deallocIP : deallocIPs) {
5339
+ builder.SetInsertPoint (deallocIP.getBlock (), deallocIP.getPoint ());
5340
+ ompBuilder.createOMPFreeShared (builder, v, arg.getType ());
5341
+ }
5342
+ } else {
5343
+ // Use the current point, which was previously set to allocIP.
5344
+ v = builder.CreateAlloca (arg.getType (), allocaAS);
5301
5345
5302
- if (allocaAS != defaultAS && arg.getType ()->isPointerTy ())
5303
- v = builder.CreateAddrSpaceCast (v, builder.getPtrTy (defaultAS));
5346
+ if (allocaAS != defaultAS && arg.getType ()->isPointerTy ())
5347
+ v = builder.CreateAddrSpaceCast (v, builder.getPtrTy (defaultAS));
5348
+ }
5304
5349
5305
5350
builder.CreateStore (&arg, v);
5306
5351
@@ -5890,8 +5935,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
5890
5935
};
5891
5936
5892
5937
auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
5893
- llvm::Value *&retVal, InsertPointTy allocaIP,
5894
- InsertPointTy codeGenIP)
5938
+ llvm::Value *&retVal, InsertPointTy allocIP,
5939
+ InsertPointTy codeGenIP,
5940
+ llvm::ArrayRef<InsertPointTy> deallocIPs)
5895
5941
-> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
5896
5942
llvm::IRBuilderBase::InsertPointGuard guard (builder);
5897
5943
builder.SetCurrentDebugLocation (llvm::DebugLoc ());
@@ -5905,9 +5951,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
5905
5951
return codeGenIP;
5906
5952
}
5907
5953
5908
- return createDeviceArgumentAccessor (mapData, arg, input, retVal, builder ,
5909
- *ompBuilder, moduleTranslation,
5910
- allocaIP , codeGenIP);
5954
+ return createDeviceArgumentAccessor (targetOp, mapData, arg, input, retVal,
5955
+ builder, *ompBuilder, moduleTranslation,
5956
+ allocIP , codeGenIP, deallocIPs );
5911
5957
};
5912
5958
5913
5959
llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
0 commit comments