diff --git a/clang/lib/CodeGen/CGGPUBuiltin.cpp b/clang/lib/CodeGen/CGGPUBuiltin.cpp index afbebd070c0542..43192c587e262c 100644 --- a/clang/lib/CodeGen/CGGPUBuiltin.cpp +++ b/clang/lib/CodeGen/CGGPUBuiltin.cpp @@ -66,39 +66,22 @@ static llvm::Function *GetVprintfDeclaration(llvm::Module &M) { // // Note that by the time this function runs, E's args have already undergone the // standard C vararg promotion (short -> int, float -> double, etc.). -RValue -CodeGenFunction::EmitNVPTXDevicePrintfCallExpr(const CallExpr *E, - ReturnValueSlot ReturnValue) { - assert(getTarget().getTriple().isNVPTX()); - assert(E->getBuiltinCallee() == Builtin::BIprintf); - assert(E->getNumArgs() >= 1); // printf always has at least one arg. - - const llvm::DataLayout &DL = CGM.getDataLayout(); - llvm::LLVMContext &Ctx = CGM.getLLVMContext(); - - CallArgList Args; - EmitCallArgs(Args, - E->getDirectCallee()->getType()->getAs(), - E->arguments(), E->getDirectCallee(), - /* ParamsToSkip = */ 0); - // We don't know how to emit non-scalar varargs. - if (llvm::any_of(llvm::drop_begin(Args), [&](const CallArg &A) { - return !A.getRValue(*this).isScalar(); - })) { - CGM.ErrorUnsupported(E, "non-scalar arg to printf"); - return RValue::get(llvm::ConstantInt::get(IntTy, 0)); - } +namespace { +llvm::Value *packArgsIntoNVPTXFormatBuffer(CodeGenFunction *CGF, + const CallArgList &Args) { + const llvm::DataLayout &DL = CGF->CGM.getDataLayout(); + llvm::LLVMContext &Ctx = CGF->CGM.getLLVMContext(); + CGBuilderTy &Builder = CGF->Builder; // Construct and fill the args buffer that we'll pass to vprintf. - llvm::Value *BufferPtr; if (Args.size() <= 1) { // If there are no args, pass a null pointer to vprintf. - BufferPtr = llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(Ctx)); + return llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(Ctx)); } else { llvm::SmallVector ArgTypes; for (unsigned I = 1, NumArgs = Args.size(); I < NumArgs; ++I) - ArgTypes.push_back(Args[I].getRValue(*this).getScalarVal()->getType()); + ArgTypes.push_back(Args[I].getRValue(*CGF).getScalarVal()->getType()); // Using llvm::StructType is correct only because printf doesn't accept // aggregates. If we had to handle aggregates here, we'd have to manually @@ -106,15 +89,40 @@ CodeGenFunction::EmitNVPTXDevicePrintfCallExpr(const CallExpr *E, // that the alignment of the llvm type was the same as the alignment of the // clang type. llvm::Type *AllocaTy = llvm::StructType::create(ArgTypes, "printf_args"); - llvm::Value *Alloca = CreateTempAlloca(AllocaTy); + llvm::Value *Alloca = CGF->CreateTempAlloca(AllocaTy); for (unsigned I = 1, NumArgs = Args.size(); I < NumArgs; ++I) { llvm::Value *P = Builder.CreateStructGEP(AllocaTy, Alloca, I - 1); - llvm::Value *Arg = Args[I].getRValue(*this).getScalarVal(); + llvm::Value *Arg = Args[I].getRValue(*CGF).getScalarVal(); Builder.CreateAlignedStore(Arg, P, DL.getPrefTypeAlign(Arg->getType())); } - BufferPtr = Builder.CreatePointerCast(Alloca, llvm::Type::getInt8PtrTy(Ctx)); + return Builder.CreatePointerCast(Alloca, llvm::Type::getInt8PtrTy(Ctx)); } +} +} // namespace + +RValue +CodeGenFunction::EmitNVPTXDevicePrintfCallExpr(const CallExpr *E, + ReturnValueSlot ReturnValue) { + assert(getTarget().getTriple().isNVPTX()); + assert(E->getBuiltinCallee() == Builtin::BIprintf); + assert(E->getNumArgs() >= 1); // printf always has at least one arg. + + CallArgList Args; + EmitCallArgs(Args, + E->getDirectCallee()->getType()->getAs(), + E->arguments(), E->getDirectCallee(), + /* ParamsToSkip = */ 0); + + // We don't know how to emit non-scalar varargs. + if (llvm::any_of(llvm::drop_begin(Args), [&](const CallArg &A) { + return !A.getRValue(*this).isScalar(); + })) { + CGM.ErrorUnsupported(E, "non-scalar arg to printf"); + return RValue::get(llvm::ConstantInt::get(IntTy, 0)); + } + + llvm::Value *BufferPtr = packArgsIntoNVPTXFormatBuffer(this, Args); // Invoke vprintf and return. llvm::Function* VprintfFunc = GetVprintfDeclaration(CGM.getModule());