diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp index 4c2577126e48b..395b25207e1df 100644 --- a/clang/lib/CodeGen/CGCall.cpp +++ b/clang/lib/CodeGen/CGCall.cpp @@ -4520,7 +4520,7 @@ void CodeGenFunction::EmitCallArgs( (isa(AC.getDecl()) && isObjCMethodWithTypeParams(cast(AC.getDecl())))) && "Argument and parameter types don't match"); - EmitCallArg(Args, *Arg, ArgTypes[Idx]); + EmitCallArg(Args, *Arg, ArgTypes[Idx], AC); // In particular, we depend on it being the last arg in Args, and the // objectsize bits depend on there only being one arg if !LeftToRight. assert(InitialArgSize + 1 == Args.size() && @@ -4611,7 +4611,7 @@ void CallArg::copyInto(CodeGenFunction &CGF, Address Addr) const { } void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E, - QualType type) { + QualType type, const AbstractCallee& AC) { DisableDebugLocationUpdates Dis(*this, E); if (const ObjCIndirectCopyRestoreExpr *CRE = dyn_cast(E)) { @@ -4627,6 +4627,26 @@ void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E, return args.add(EmitReferenceBindingToExpr(E), type); } + auto IsSingleParameterCopyConstructor = [&]() { + if(1 != AC.getNumParams()) return false; + if (const CXXRecordDecl* SubRecordDecl = type->getAsCXXRecordDecl()) { + if (const CXXConstructorDecl* ConstructorDecl = dyn_cast(AC.getDecl())) { + if(const CXXRecordDecl* BaseRecordDecl = dyn_cast(ConstructorDecl->getParent())) { + if(SubRecordDecl->isDerivedFrom(BaseRecordDecl)) { + return true; + } + } + } + } + return false; + }; + if(IsSingleParameterCopyConstructor()) { + AggValueSlot Slot = args.isUsingInAlloca() + ? createPlaceholderSlot(*this, type) : CreateAggTemp(type, "agg.tmp"); + RValue RV = Slot.asRValue(); + return args.add(RV, type); + } + bool HasAggregateEvalKind = hasAggregateEvaluationKind(type); // In the Microsoft C++ ABI, aggregate arguments are destructed by the callee. diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index 618e78809db40..f2305a9307396 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -4746,7 +4746,7 @@ class CodeGenFunction : public CodeGenTypeCache { AbstractCallee AC, unsigned ParmNum); /// EmitCallArg - Emit a single call argument. - void EmitCallArg(CallArgList &args, const Expr *E, QualType ArgType); + void EmitCallArg(CallArgList &args, const Expr *E, QualType ArgType, const AbstractCallee& AC); /// EmitDelegateCallArg - We are performing a delegate call; that /// is, the current function is delegating to another one. Produce diff --git a/clang/unittests/CodeGen/CMakeLists.txt b/clang/unittests/CodeGen/CMakeLists.txt index a437f441568f2..8870237c85539 100644 --- a/clang/unittests/CodeGen/CMakeLists.txt +++ b/clang/unittests/CodeGen/CMakeLists.txt @@ -9,6 +9,7 @@ add_clang_unittest(ClangCodeGenTests CodeGenExternalTest.cpp TBAAMetadataTest.cpp CheckTargetFeaturesTest.cpp + TemplateInstantiationTest.cpp ) clang_target_link_libraries(ClangCodeGenTests diff --git a/llvm/TemplateInstantiationTest.cpp b/llvm/TemplateInstantiationTest.cpp new file mode 100644 index 0000000000000..fbb7b94e6f5aa --- /dev/null +++ b/llvm/TemplateInstantiationTest.cpp @@ -0,0 +1,216 @@ +//===- unittests/CodeGen/TemplateInstantiationTest.cpp - template instantiation test -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestCompiler.h" + +#include "clang/AST/ASTConsumer.h" +#include "clang/AST/ASTContext.h" +#include "clang/AST/GlobalDecl.h" +#include "clang/AST/RecursiveASTVisitor.h" +#include "clang/Basic/TargetInfo.h" +#include "clang/CodeGen/CodeGenABITypes.h" +#include "clang/CodeGen/ModuleBuilder.h" +#include "clang/Frontend/CompilerInstance.h" +#include "clang/Lex/Preprocessor.h" +#include "clang/Parse/ParseAST.h" +#include "clang/Sema/Sema.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/TargetParser/Host.h" +#include "llvm/TargetParser/Triple.h" +#include "gtest/gtest.h" + +#include "llvm/Analysis/CallGraph.h" +#include + +using namespace llvm; +using namespace clang; + +namespace { + +static const bool DebugThisTest = false; + +// forward declarations +struct TemplateInstantiationASTConsumer; +static void test_instantiation_fns(TemplateInstantiationASTConsumer *my); +static bool test_instantiation_fns_ran; + +// This forwards the calls to the Clang CodeGenerator +// so that we can test CodeGen functions while it is open. +// It accumulates toplevel decls in HandleTopLevelDecl and +// calls test_instantiation_fns() in HandleTranslationUnit +// after forwarding that function to the CodeGenerator. + +struct TemplateInstantiationASTConsumer : public ASTConsumer { + std::unique_ptr Builder; + std::vector toplevel_decls; + + TemplateInstantiationASTConsumer(std::unique_ptr Builder_in) + : ASTConsumer(), Builder(std::move(Builder_in)) + { + } + + ~TemplateInstantiationASTConsumer() { } + + void Initialize(ASTContext &Context) override; + void HandleCXXStaticMemberVarInstantiation(VarDecl *VD) override; + bool HandleTopLevelDecl(DeclGroupRef D) override; + void HandleInlineFunctionDefinition(FunctionDecl *D) override; + void HandleInterestingDecl(DeclGroupRef D) override; + void HandleTranslationUnit(ASTContext &Ctx) override; + void HandleTagDeclDefinition(TagDecl *D) override; + void HandleTagDeclRequiredDefinition(const TagDecl *D) override; + void HandleCXXImplicitFunctionInstantiation(FunctionDecl *D) override; + void HandleTopLevelDeclInObjCContainer(DeclGroupRef D) override; + void HandleImplicitImportDecl(ImportDecl *D) override; + void CompleteTentativeDefinition(VarDecl *D) override; + void AssignInheritanceModel(CXXRecordDecl *RD) override; + void HandleVTable(CXXRecordDecl *RD) override; + ASTMutationListener *GetASTMutationListener() override; + ASTDeserializationListener *GetASTDeserializationListener() override; + void PrintStats() override; + bool shouldSkipFunctionBody(Decl *D) override; +}; + +void TemplateInstantiationASTConsumer::Initialize(ASTContext &Context) { + Builder->Initialize(Context); +} + +bool TemplateInstantiationASTConsumer::HandleTopLevelDecl(DeclGroupRef DG) { + + for (DeclGroupRef::iterator I = DG.begin(), E = DG.end(); I != E; ++I) { + toplevel_decls.push_back(*I); + } + + return Builder->HandleTopLevelDecl(DG); +} + +void TemplateInstantiationASTConsumer::HandleInlineFunctionDefinition(FunctionDecl *D) { + Builder->HandleInlineFunctionDefinition(D); +} + +void TemplateInstantiationASTConsumer::HandleInterestingDecl(DeclGroupRef D) { + Builder->HandleInterestingDecl(D); +} + +void TemplateInstantiationASTConsumer::HandleTranslationUnit(ASTContext &Context) { + // HandleTranslationUnit can close the module + Builder->HandleTranslationUnit(Context); + test_instantiation_fns(this); +} + +void TemplateInstantiationASTConsumer::HandleTagDeclDefinition(TagDecl *D) { + Builder->HandleTagDeclDefinition(D); +} + +void TemplateInstantiationASTConsumer::HandleTagDeclRequiredDefinition(const TagDecl *D) { + Builder->HandleTagDeclRequiredDefinition(D); +} + +void TemplateInstantiationASTConsumer::HandleCXXImplicitFunctionInstantiation(FunctionDecl *D) { + Builder->HandleCXXImplicitFunctionInstantiation(D); +} + +void TemplateInstantiationASTConsumer::HandleTopLevelDeclInObjCContainer(DeclGroupRef D) { + Builder->HandleTopLevelDeclInObjCContainer(D); +} + +void TemplateInstantiationASTConsumer::HandleImplicitImportDecl(ImportDecl *D) { + Builder->HandleImplicitImportDecl(D); +} + +void TemplateInstantiationASTConsumer::CompleteTentativeDefinition(VarDecl *D) { + Builder->CompleteTentativeDefinition(D); +} + +void TemplateInstantiationASTConsumer::AssignInheritanceModel(CXXRecordDecl *RD) { + Builder->AssignInheritanceModel(RD); +} + +void TemplateInstantiationASTConsumer::HandleCXXStaticMemberVarInstantiation(VarDecl *VD) { + Builder->HandleCXXStaticMemberVarInstantiation(VD); +} + +void TemplateInstantiationASTConsumer::HandleVTable(CXXRecordDecl *RD) { + Builder->HandleVTable(RD); + } + +ASTMutationListener *TemplateInstantiationASTConsumer::GetASTMutationListener() { + return Builder->GetASTMutationListener(); +} + +ASTDeserializationListener *TemplateInstantiationASTConsumer::GetASTDeserializationListener() { + return Builder->GetASTDeserializationListener(); +} + +void TemplateInstantiationASTConsumer::PrintStats() { + Builder->PrintStats(); +} + +bool TemplateInstantiationASTConsumer::shouldSkipFunctionBody(Decl *D) { + return Builder->shouldSkipFunctionBody(D); +} + +const char TestProgram[] = "struct base { public : base() {} template base(T x) {} }; struct derived : public base { public: derived() {} derived(derived& that): base(that) {} }; int main() { derived d1; derived d2 = d1; return 0;}"; + +bool hasCycles(const Function *CurrentFunction, + std::unordered_set &VisitedFunctions, + std::unordered_set &RecursionStack, + const CallGraphNode* CurrentNode) { + VisitedFunctions.insert(CurrentFunction); + RecursionStack.insert(CurrentFunction); + for (CallGraphNode::const_iterator IT = CurrentNode->begin(), END = CurrentNode->end(); IT != END; ++IT) { + if (const Function *CalleeFunction = IT->second->getFunction()) { + if (RecursionStack.count(CalleeFunction)) { + return true; + } + if (VisitedFunctions.count(CalleeFunction) == 0 && hasCycles(CalleeFunction, VisitedFunctions, RecursionStack, IT->second)) { + return true; + } + } + } + RecursionStack.erase(CurrentFunction); + return false; +} + +static void test_instantiation_fns(TemplateInstantiationASTConsumer *InstantiationASTConsumer) { + test_instantiation_fns_ran = true; + llvm::Module* Mymodule = InstantiationASTConsumer->Builder->GetModule(); + CallGraph MyCallGraph(*Mymodule); + std::unordered_set VisitedFunctions; + std::unordered_set RecursionStack; + for (llvm::CallGraph::const_iterator IT = MyCallGraph.begin(), END = MyCallGraph.end(); + IT != END; ++IT) { + const Function* MyFunction = IT->first; + const CallGraphNode* MyCallGraphNode = IT->second.get(); + if (MyFunction && VisitedFunctions.count(MyFunction) == 0){ + if(hasCycles(MyFunction, VisitedFunctions, RecursionStack, MyCallGraphNode)) { + test_instantiation_fns_ran = false; + break; + } + } + } +} + +TEST(BaseConstructorTemplateInstantiationTest, BaseConstructorTemplateInstantiationTest) { + clang::LangOptions LO; + LO.CPlusPlus = 1; + TestCompiler Compiler(LO); + auto CustomASTConsumer + = std::make_unique(std::move(Compiler.CG)); + + Compiler.init(TestProgram, std::move(CustomASTConsumer)); + ParseAST(Compiler.compiler.getSema(), false, false); + + ASSERT_TRUE(test_instantiation_fns_ran); +} + +} // end anonymous namespace +