Skip to content

Commit

Permalink
[OpenMP][Fix] Properly inherit calling convention
Browse files Browse the repository at this point in the history
Previously in OpenMPOpt we did not correctly inherit the calling
convention of the callee when creating new OpenMP runtime calls. This
created issues when the calling convention was changed during
`GlobalOpt` but a new call was created without the correct calling
convention. This led to the call being replaced with a poison value in
`InstCombine` due to undefined behaviour, causing large portions of
the program to be incorrectly eliminated. This patch correctly inherits
the existing calling convention from the callee.

Reviewed By: tianshilei1992, jdoerfert

Differential Revision: https://reviews.llvm.org/D118059
  • Loading branch information
jhuber6 committed Jan 25, 2022
1 parent 16bff06 commit 06cfdd5
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 19 deletions.
57 changes: 40 additions & 17 deletions llvm/lib/Transforms/IPO/OpenMPOpt.cpp
Expand Up @@ -417,6 +417,12 @@ struct OMPInformationCache : public InformationCache {
recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
}

/// Make the call site \p CI use the same calling convention as \p Callee.
///
/// The convention is copied only when the callee value held by \p Callee is
/// a `Function`; for any other kind of callee value \p CI is left untouched.
void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
  auto *CalledFn = dyn_cast<Function>(Callee.getCallee());
  // Nothing to inherit from unless the callee is a plain function.
  if (!CalledFn)
    return;
  CI->setCallingConv(CalledFn->getCallingConv());
}

/// Helper to initialize all runtime function information for those defined
/// in OpenMPKinds.def.
void initializeRuntimeFunctions() {
Expand Down Expand Up @@ -1531,6 +1537,7 @@ struct OpenMPOpt {

CallInst *IssueCallsite =
CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
RuntimeCall.eraseFromParent();

// Add "wait" runtime call declaration:
Expand All @@ -1543,7 +1550,9 @@ struct OpenMPOpt {
OffloadArray::DeviceIDArgNum), // device_id.
Handle // handle to wait on.
};
CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
CallInst *WaitCallsite = CallInst::Create(
WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);

return true;
}
Expand Down Expand Up @@ -3241,8 +3250,10 @@ struct AAKernelInfoFunction : AAKernelInfo {
FunctionCallee HardwareTidFn =
OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
Value *Tid =
CallInst *Tid =
OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
Tid->setDebugLoc(DL);
OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
OMPInfoCache.OMPBuilder.Builder
.CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
Expand All @@ -3255,14 +3266,18 @@ struct AAKernelInfoFunction : AAKernelInfo {
M, OMPRTL___kmpc_barrier_simple_spmd);
OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid})
->setDebugLoc(DL);
CallInst *Barrier =
OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
Barrier->setDebugLoc(DL);
OMPInfoCache.setCallingConvention(BarrierFn, Barrier);

// Second barrier ensures workers have read broadcast values.
if (HasBroadcastValues)
CallInst::Create(BarrierFn, {Ident, Tid}, "",
RegionBarrierBB->getTerminator())
->setDebugLoc(DL);
if (HasBroadcastValues) {
CallInst *Barrier = CallInst::Create(BarrierFn, {Ident, Tid}, "",
RegionBarrierBB->getTerminator());
Barrier->setDebugLoc(DL);
OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
}
};

auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
Expand Down Expand Up @@ -3532,10 +3547,12 @@ struct AAKernelInfoFunction : AAKernelInfo {
FunctionCallee WarpSizeFn =
OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
M, OMPRTL___kmpc_get_warp_size);
Instruction *BlockHwSize =
CallInst *BlockHwSize =
CallInst::Create(BlockHwSizeFn, "block.hw_size", InitBB);
OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
BlockHwSize->setDebugLoc(DLoc);
Instruction *WarpSize = CallInst::Create(WarpSizeFn, "warp.size", InitBB);
CallInst *WarpSize = CallInst::Create(WarpSizeFn, "warp.size", InitBB);
OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
WarpSize->setDebugLoc(DLoc);
Instruction *BlockSize =
BinaryOperator::CreateSub(BlockHwSize, WarpSize, "block.size", InitBB);
Expand Down Expand Up @@ -3575,8 +3592,10 @@ struct AAKernelInfoFunction : AAKernelInfo {
FunctionCallee BarrierFn =
OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
M, OMPRTL___kmpc_barrier_simple_generic);
CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB)
->setDebugLoc(DLoc);
CallInst *Barrier =
CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
Barrier->setDebugLoc(DLoc);

if (WorkFnAI->getType()->getPointerAddressSpace() !=
(unsigned int)AddressSpace::Generic) {
Expand All @@ -3592,8 +3611,9 @@ struct AAKernelInfoFunction : AAKernelInfo {
FunctionCallee KernelParallelFn =
OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
M, OMPRTL___kmpc_kernel_parallel);
Instruction *IsActiveWorker = CallInst::Create(
CallInst *IsActiveWorker = CallInst::Create(
KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
IsActiveWorker->setDebugLoc(DLoc);
Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
StateMachineBeginBB);
Expand Down Expand Up @@ -3673,10 +3693,13 @@ struct AAKernelInfoFunction : AAKernelInfo {
StateMachineIfCascadeCurrentBB)
->setDebugLoc(DLoc);

CallInst::Create(OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
M, OMPRTL___kmpc_kernel_end_parallel),
{}, "", StateMachineEndParallelBB)
->setDebugLoc(DLoc);
FunctionCallee EndParallelFn =
OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
M, OMPRTL___kmpc_kernel_end_parallel);
CallInst *EndParallel =
CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
EndParallel->setDebugLoc(DLoc);
BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
->setDebugLoc(DLoc);

Expand Down
6 changes: 4 additions & 2 deletions llvm/test/Transforms/OpenMP/spmdization.ll
Expand Up @@ -1430,7 +1430,7 @@ define internal void @__omp_outlined__6(i32* noalias %.global_tid., i32* noalias
; AMDGPU-NEXT: [[X_ON_STACK:%.*]] = bitcast i8* addrspacecast (i8 addrspace(3)* getelementptr inbounds ([4 x i8], [4 x i8] addrspace(3)* @x.1, i32 0, i32 0) to i8*) to i32*
; AMDGPU-NEXT: br label [[REGION_CHECK_TID:%.*]]
; AMDGPU: region.check.tid:
; AMDGPU-NEXT: [[TMP0:%.*]] = call i32 @__kmpc_get_hardware_thread_id_in_block()
; AMDGPU-NEXT: [[TMP0:%.*]] = call fastcc i32 @__kmpc_get_hardware_thread_id_in_block()
; AMDGPU-NEXT: [[TMP1:%.*]] = icmp eq i32 [[TMP0]], 0
; AMDGPU-NEXT: br i1 [[TMP1]], label [[REGION_GUARDED:%.*]], label [[REGION_BARRIER:%.*]]
; AMDGPU: region.guarded:
Expand Down Expand Up @@ -1466,7 +1466,7 @@ define internal void @__omp_outlined__6(i32* noalias %.global_tid., i32* noalias
; NVPTX-NEXT: [[X_ON_STACK:%.*]] = bitcast i8* addrspacecast (i8 addrspace(3)* getelementptr inbounds ([4 x i8], [4 x i8] addrspace(3)* @x1, i32 0, i32 0) to i8*) to i32*
; NVPTX-NEXT: br label [[REGION_CHECK_TID:%.*]]
; NVPTX: region.check.tid:
; NVPTX-NEXT: [[TMP0:%.*]] = call i32 @__kmpc_get_hardware_thread_id_in_block()
; NVPTX-NEXT: [[TMP0:%.*]] = call fastcc i32 @__kmpc_get_hardware_thread_id_in_block()
; NVPTX-NEXT: [[TMP1:%.*]] = icmp eq i32 [[TMP0]], 0
; NVPTX-NEXT: br i1 [[TMP1]], label [[REGION_GUARDED:%.*]], label [[REGION_BARRIER:%.*]]
; NVPTX: region.guarded:
Expand Down Expand Up @@ -2328,6 +2328,8 @@ entry:
ret void
}

declare fastcc i32 @__kmpc_get_hardware_thread_id_in_block();

attributes #0 = { alwaysinline convergent norecurse nounwind }
attributes #1 = { argmemonly mustprogress nofree nosync nounwind willreturn }
attributes #2 = { convergent }
Expand Down

0 comments on commit 06cfdd5

Please sign in to comment.