diff --git a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp index 20f32ffeba3bf..554e66988f090 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp @@ -7,8 +7,11 @@ //===----------------------------------------------------------------------===// // // This pass modifies function signatures containing aggregate arguments -// and/or return value. Also it substitutes some llvm intrinsic calls by -// function calls, generating these functions as the translator does. +// and/or return value before IRTranslator. Information about the original +// signatures is stored in metadata. It is used during call lowering to +// restore correct SPIR-V types of function arguments and return values. +// This pass also substitutes some llvm intrinsic calls with calls to newly +// generated functions (as the Khronos LLVM/SPIR-V Translator does). // // NOTE: this pass is a module-level one due to the necessity to modify // GVs/functions. @@ -33,7 +36,8 @@ void initializeSPIRVPrepareFunctionsPass(PassRegistry &); namespace { class SPIRVPrepareFunctions : public ModulePass { - Function *processFunctionSignature(Function *F); + bool substituteIntrinsicCalls(Function *F); + Function *removeAggregateTypesFromSignature(Function *F); public: static char ID; @@ -57,68 +61,6 @@ char SPIRVPrepareFunctions::ID = 0; INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions", "SPIRV prepare functions", false, false) -Function *SPIRVPrepareFunctions::processFunctionSignature(Function *F) { - IRBuilder<> B(F->getContext()); - - bool IsRetAggr = F->getReturnType()->isAggregateType(); - bool HasAggrArg = - std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) { - return Arg.getType()->isAggregateType(); - }); - bool DoClone = IsRetAggr || HasAggrArg; - if (!DoClone) - return F; - SmallVector, 4> ChangedTypes; - Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType(); - if (IsRetAggr) - ChangedTypes.push_back(std::pair(-1, F->getReturnType())); - SmallVector ArgTypes; - for (const auto &Arg : F->args()) { - if (Arg.getType()->isAggregateType()) { - ArgTypes.push_back(B.getInt32Ty()); - ChangedTypes.push_back( - std::pair(Arg.getArgNo(), Arg.getType())); - } else - ArgTypes.push_back(Arg.getType()); - } - FunctionType *NewFTy = - FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg()); - Function *NewF = - Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent()); - - ValueToValueMapTy VMap; - auto NewFArgIt = NewF->arg_begin(); - for (auto &Arg : F->args()) { - StringRef ArgName = Arg.getName(); - NewFArgIt->setName(ArgName); - VMap[&Arg] = &(*NewFArgIt++); - } - SmallVector Returns; - - CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly, - Returns); - NewF->takeName(F); - - NamedMDNode *FuncMD = - F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs"); - SmallVector MDArgs; - MDArgs.push_back(MDString::get(B.getContext(), NewF->getName())); - for (auto &ChangedTyP : ChangedTypes) - MDArgs.push_back(MDNode::get( - B.getContext(), - {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)), - ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))})); - MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs); - FuncMD->addOperand(ThisFuncMD); - - for (auto *U : make_early_inc_range(F->users())) { - if (auto *CI = dyn_cast(U)) - CI->mutateFunctionType(NewF->getFunctionType()); - U->replaceUsesOfWith(F, NewF); - } - return NewF; -} - std::string lowerLLVMIntrinsicName(IntrinsicInst *II) { Function *IntrinsicFunc = II->getCalledFunction(); assert(IntrinsicFunc && "Missing function"); @@ -142,15 +84,16 @@ static Function *getOrCreateFunction(Module *M, Type *RetTy, return NewF; } -static void lowerIntrinsicToFunction(Module *M, IntrinsicInst *Intrinsic) { +static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) { // For @llvm.memset.* intrinsic cases with constant value and length arguments // are emulated via "storing" a constant array to the destination. For other // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the // intrinsic to a loop via expandMemSetAsLoop(). if (auto *MSI = dyn_cast(Intrinsic)) if (isa(MSI->getValue()) && isa(MSI->getLength())) - return; // It is handled later using OpCopyMemorySized. + return false; // It is handled later using OpCopyMemorySized. + Module *M = Intrinsic->getModule(); std::string FuncName = lowerLLVMIntrinsicName(Intrinsic); if (Intrinsic->isVolatile()) FuncName += ".volatile"; @@ -158,7 +101,7 @@ static void lowerIntrinsicToFunction(Module *M, IntrinsicInst *Intrinsic) { Function *F = M->getFunction(FuncName); if (F) { Intrinsic->setCalledFunction(F); - return; + return true; } // TODO copy arguments attributes: nocapture writeonly. FunctionCallee FC = @@ -202,14 +145,15 @@ static void lowerIntrinsicToFunction(Module *M, IntrinsicInst *Intrinsic) { default: break; } - return; + return true; } -static void lowerFunnelShifts(Module *M, IntrinsicInst *FSHIntrinsic) { +static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) { // Get a separate function - otherwise, we'd have to rework the CFG of the // current one. Then simply replace the intrinsic uses with a call to the new // function. // Generate LLVM IR for i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c) + Module *M = FSHIntrinsic->getModule(); FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType(); Type *FSHRetTy = FSHFuncTy->getReturnType(); const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic); @@ -265,12 +209,13 @@ static void lowerFunnelShifts(Module *M, IntrinsicInst *FSHIntrinsic) { FSHIntrinsic->setCalledFunction(FSHFunc); } -static void buildUMulWithOverflowFunc(Module *M, Function *UMulFunc) { +static void buildUMulWithOverflowFunc(Function *UMulFunc) { // The function body is already created. if (!UMulFunc->empty()) return; - BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", UMulFunc); + BasicBlock *EntryBB = BasicBlock::Create(UMulFunc->getParent()->getContext(), + "entry", UMulFunc); IRBuilder<> IRB(EntryBB); // Build the actual unsigned multiplication logic with the overflow // indication. Do unsigned multiplication Mul = A * B. Then check @@ -288,65 +233,132 @@ static void buildUMulWithOverflowFunc(Module *M, Function *UMulFunc) { IRB.CreateRet(Res); } -static void lowerUMulWithOverflow(Module *M, IntrinsicInst *UMulIntrinsic) { +static void lowerUMulWithOverflow(IntrinsicInst *UMulIntrinsic) { // Get a separate function - otherwise, we'd have to rework the CFG of the // current one. Then simply replace the intrinsic uses with a call to the new // function. + Module *M = UMulIntrinsic->getModule(); FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType(); Type *FSHLRetTy = UMulFuncTy->getReturnType(); const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic); Function *UMulFunc = getOrCreateFunction(M, FSHLRetTy, UMulFuncTy->params(), FuncName); - buildUMulWithOverflowFunc(M, UMulFunc); + buildUMulWithOverflowFunc(UMulFunc); UMulIntrinsic->setCalledFunction(UMulFunc); } -static void substituteIntrinsicCalls(Module *M, Function *F) { +// Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics +// or calls to proper generated functions. Returns True if F was modified. +bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) { + bool Changed = false; for (BasicBlock &BB : *F) { for (Instruction &I : BB) { auto Call = dyn_cast(&I); if (!Call) continue; - Call->setTailCall(false); Function *CF = Call->getCalledFunction(); if (!CF || !CF->isIntrinsic()) continue; auto *II = cast(Call); if (II->getIntrinsicID() == Intrinsic::memset || II->getIntrinsicID() == Intrinsic::bswap) - lowerIntrinsicToFunction(M, II); + Changed |= lowerIntrinsicToFunction(II); else if (II->getIntrinsicID() == Intrinsic::fshl || - II->getIntrinsicID() == Intrinsic::fshr) - lowerFunnelShifts(M, II); - else if (II->getIntrinsicID() == Intrinsic::umul_with_overflow) - lowerUMulWithOverflow(M, II); + II->getIntrinsicID() == Intrinsic::fshr) { + lowerFunnelShifts(II); + Changed = true; + } else if (II->getIntrinsicID() == Intrinsic::umul_with_overflow) { + lowerUMulWithOverflow(II); + Changed = true; + } } } + return Changed; +} + +// Returns F if aggregate argument/return types are not present or cloned F +// function with the types replaced by i32 types. The change in types is +// noted in 'spv.cloned_funcs' metadata for later restoration. +Function * +SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) { + IRBuilder<> B(F->getContext()); + + bool IsRetAggr = F->getReturnType()->isAggregateType(); + bool HasAggrArg = + std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) { + return Arg.getType()->isAggregateType(); + }); + bool DoClone = IsRetAggr || HasAggrArg; + if (!DoClone) + return F; + SmallVector, 4> ChangedTypes; + Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType(); + if (IsRetAggr) + ChangedTypes.push_back(std::pair(-1, F->getReturnType())); + SmallVector ArgTypes; + for (const auto &Arg : F->args()) { + if (Arg.getType()->isAggregateType()) { + ArgTypes.push_back(B.getInt32Ty()); + ChangedTypes.push_back( + std::pair(Arg.getArgNo(), Arg.getType())); + } else + ArgTypes.push_back(Arg.getType()); + } + FunctionType *NewFTy = + FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg()); + Function *NewF = + Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent()); + + ValueToValueMapTy VMap; + auto NewFArgIt = NewF->arg_begin(); + for (auto &Arg : F->args()) { + StringRef ArgName = Arg.getName(); + NewFArgIt->setName(ArgName); + VMap[&Arg] = &(*NewFArgIt++); + } + SmallVector Returns; + + CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly, + Returns); + NewF->takeName(F); + + NamedMDNode *FuncMD = + F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs"); + SmallVector MDArgs; + MDArgs.push_back(MDString::get(B.getContext(), NewF->getName())); + for (auto &ChangedTyP : ChangedTypes) + MDArgs.push_back(MDNode::get( + B.getContext(), + {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)), + ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))})); + MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs); + FuncMD->addOperand(ThisFuncMD); + + for (auto *U : make_early_inc_range(F->users())) { + if (auto *CI = dyn_cast(U)) + CI->mutateFunctionType(NewF->getFunctionType()); + U->replaceUsesOfWith(F, NewF); + } + return NewF; } bool SPIRVPrepareFunctions::runOnModule(Module &M) { + bool Changed = false; for (Function &F : M) - substituteIntrinsicCalls(&M, &F); + Changed |= substituteIntrinsicCalls(&F); std::vector FuncsWorklist; - bool Changed = false; for (auto &F : M) FuncsWorklist.push_back(&F); - for (auto *Func : FuncsWorklist) { - Function *F = processFunctionSignature(Func); - - bool CreatedNewF = F != Func; + for (auto *F : FuncsWorklist) { + Function *NewF = removeAggregateTypesFromSignature(F); - if (Func->isDeclaration()) { - Changed |= CreatedNewF; - continue; + if (NewF != F) { + F->eraseFromParent(); + Changed = true; } - - if (CreatedNewF) - Func->eraseFromParent(); } - return Changed; }