diff --git a/clang/tools/clang-linker-wrapper/OffloadWrapper.cpp b/clang/tools/clang-linker-wrapper/OffloadWrapper.cpp index 4bbfba777e185..3e1dd874216cc 100644 --- a/clang/tools/clang-linker-wrapper/OffloadWrapper.cpp +++ b/clang/tools/clang-linker-wrapper/OffloadWrapper.cpp @@ -457,45 +457,41 @@ void createRegisterFatbinFunction(Module &M, GlobalVariable *FatbinDesc, IsHIP ? ".hip.fatbin_unreg" : ".cuda.fatbin_unreg", &M); DtorFunc->setSection(".text.startup"); + auto *PtrTy = PointerType::getUnqual(C); + // Get the __cudaRegisterFatBinary function declaration. - auto *RegFatTy = FunctionType::get(PointerType::getUnqual(C)->getPointerTo(), - PointerType::getUnqual(C), - /*isVarArg*/ false); + auto *RegFatTy = FunctionType::get(PtrTy, PtrTy, /*isVarArg=*/false); FunctionCallee RegFatbin = M.getOrInsertFunction( IsHIP ? "__hipRegisterFatBinary" : "__cudaRegisterFatBinary", RegFatTy); // Get the __cudaRegisterFatBinaryEnd function declaration. - auto *RegFatEndTy = FunctionType::get( - Type::getVoidTy(C), PointerType::getUnqual(C)->getPointerTo(), - /*isVarArg*/ false); + auto *RegFatEndTy = + FunctionType::get(Type::getVoidTy(C), PtrTy, /*isVarArg=*/false); FunctionCallee RegFatbinEnd = M.getOrInsertFunction("__cudaRegisterFatBinaryEnd", RegFatEndTy); // Get the __cudaUnregisterFatBinary function declaration. - auto *UnregFatTy = FunctionType::get( - Type::getVoidTy(C), PointerType::getUnqual(C)->getPointerTo(), - /*isVarArg*/ false); + auto *UnregFatTy = + FunctionType::get(Type::getVoidTy(C), PtrTy, /*isVarArg=*/false); FunctionCallee UnregFatbin = M.getOrInsertFunction( IsHIP ? "__hipUnregisterFatBinary" : "__cudaUnregisterFatBinary", UnregFatTy); auto *AtExitTy = - FunctionType::get(Type::getInt32Ty(C), DtorFuncTy->getPointerTo(), - /*isVarArg*/ false); + FunctionType::get(Type::getInt32Ty(C), PtrTy, /*isVarArg=*/false); FunctionCallee AtExit = M.getOrInsertFunction("atexit", AtExitTy); auto *BinaryHandleGlobal = new llvm::GlobalVariable( - M, PointerType::getUnqual(C)->getPointerTo(), false, - llvm::GlobalValue::InternalLinkage, - llvm::ConstantPointerNull::get(PointerType::getUnqual(C)->getPointerTo()), + M, PtrTy, false, llvm::GlobalValue::InternalLinkage, + llvm::ConstantPointerNull::get(PtrTy), IsHIP ? ".hip.binary_handle" : ".cuda.binary_handle"); // Create the constructor to register this image with the runtime. IRBuilder<> CtorBuilder(BasicBlock::Create(C, "entry", CtorFunc)); CallInst *Handle = CtorBuilder.CreateCall( - RegFatbin, ConstantExpr::getPointerBitCastOrAddrSpaceCast( - FatbinDesc, PointerType::getUnqual(C))); + RegFatbin, + ConstantExpr::getPointerBitCastOrAddrSpaceCast(FatbinDesc, PtrTy)); CtorBuilder.CreateAlignedStore( Handle, BinaryHandleGlobal, - Align(M.getDataLayout().getPointerTypeSize(PointerType::getUnqual(C)))); + Align(M.getDataLayout().getPointerTypeSize(PtrTy))); CtorBuilder.CreateCall(createRegisterGlobalsFunction(M, IsHIP), Handle); if (!IsHIP) CtorBuilder.CreateCall(RegFatbinEnd, Handle); @@ -507,8 +503,8 @@ void createRegisterFatbinFunction(Module &M, GlobalVariable *FatbinDesc, // `atexit()` intead. IRBuilder<> DtorBuilder(BasicBlock::Create(C, "entry", DtorFunc)); LoadInst *BinaryHandle = DtorBuilder.CreateAlignedLoad( - PointerType::getUnqual(C)->getPointerTo(), BinaryHandleGlobal, - Align(M.getDataLayout().getPointerTypeSize(PointerType::getUnqual(C)))); + PtrTy, BinaryHandleGlobal, + Align(M.getDataLayout().getPointerTypeSize(PtrTy))); DtorBuilder.CreateCall(UnregFatbin, BinaryHandle); DtorBuilder.CreateRetVoid();