diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp index 456c94090a149..00e3f1b8c6f72 100644 --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -2041,7 +2041,8 @@ bool OpenMPOpt::rewriteDeviceCodeStateMachine() { UndefValue::get(Int8Ty), F->getName() + ".ID"); for (Use *U : ToBeReplacedStateMachineUses) - U->set(ConstantExpr::getBitCast(ID, U->get()->getType())); + U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast( + ID, U->get()->getType())); ++NumOpenMPParallelRegionsReplacedInGPUStateMachine; @@ -3455,10 +3456,14 @@ struct AAKernelInfoFunction : AAKernelInfo { IsWorker->setDebugLoc(DLoc); BranchInst::Create(StateMachineBeginBB, UserCodeEntryBB, IsWorker, InitBB); + Module &M = *Kernel->getParent(); + // Create local storage for the work function pointer. + const DataLayout &DL = M.getDataLayout(); Type *VoidPtrTy = Type::getInt8PtrTy(Ctx); - AllocaInst *WorkFnAI = new AllocaInst(VoidPtrTy, 0, "worker.work_fn.addr", - &Kernel->getEntryBlock().front()); + Instruction *WorkFnAI = + new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr, + "worker.work_fn.addr", &Kernel->getEntryBlock().front()); WorkFnAI->setDebugLoc(DLoc); auto &OMPInfoCache = static_cast(A.getInfoCache()); @@ -3471,13 +3476,23 @@ struct AAKernelInfoFunction : AAKernelInfo { Value *Ident = KernelInitCB->getArgOperand(0); Value *GTid = KernelInitCB; - Module &M = *Kernel->getParent(); FunctionCallee BarrierFn = OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( M, OMPRTL___kmpc_barrier_simple_spmd); CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB) ->setDebugLoc(DLoc); + if (WorkFnAI->getType()->getPointerAddressSpace() != + (unsigned int)AddressSpace::Generic) { + WorkFnAI = new AddrSpaceCastInst( + WorkFnAI, + PointerType::getWithSamePointeeType( + cast(WorkFnAI->getType()), + (unsigned int)AddressSpace::Generic), + WorkFnAI->getName() + ".generic", StateMachineBeginBB); + WorkFnAI->setDebugLoc(DLoc); + } + FunctionCallee KernelParallelFn = OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( M, OMPRTL___kmpc_kernel_parallel);