From da84741bf676cc5154f0a0577b5f267365ed3e43 Mon Sep 17 00:00:00 2001 From: Matheus Izvekov Date: Wed, 8 Oct 2025 22:10:46 -0300 Subject: [PATCH] [clang] fix transform for constant template parameter type subst node This fixes the transform to use the correct parameter type for an AssociatedDecl which has been fully specialized. Instead of using the type for the parameter of the specialized template, this uses the type of the argument it has been specialized with. This fixes a regression reported here: https://github.com/llvm/llvm-project/pull/161029#issuecomment-3375478990 Since this regression was never released, there are no release notes. --- clang/include/clang/AST/DeclTemplate.h | 7 ++- clang/lib/AST/DeclTemplate.cpp | 80 +++++++++++++----------- clang/lib/AST/ExprCXX.cpp | 4 +- clang/lib/AST/TemplateName.cpp | 10 ++- clang/lib/AST/Type.cpp | 14 ++--- clang/lib/Sema/TreeTransform.h | 17 +++-- clang/test/CodeGenCXX/template-cxx20.cpp | 24 +++++++ 7 files changed, 93 insertions(+), 63 deletions(-) create mode 100644 clang/test/CodeGenCXX/template-cxx20.cpp diff --git a/clang/include/clang/AST/DeclTemplate.h b/clang/include/clang/AST/DeclTemplate.h index bba72365089f9..058316fa385cf 100644 --- a/clang/include/clang/AST/DeclTemplate.h +++ b/clang/include/clang/AST/DeclTemplate.h @@ -3395,9 +3395,10 @@ inline UnsignedOrNone getExpandedPackSize(const NamedDecl *Param) { return std::nullopt; } -/// Internal helper used by Subst* nodes to retrieve the parameter list -/// for their AssociatedDecl. -TemplateParameterList *getReplacedTemplateParameterList(const Decl *D); +/// Internal helper used by Subst* nodes to retrieve a parameter from the +/// AssociatedDecl, and the template argument substituted into it, if any. +std::tuple +getReplacedTemplateParameter(Decl *D, unsigned Index); } // namespace clang diff --git a/clang/lib/AST/DeclTemplate.cpp b/clang/lib/AST/DeclTemplate.cpp index b6bb6117d42af..f22f1c9470ed0 100644 --- a/clang/lib/AST/DeclTemplate.cpp +++ b/clang/lib/AST/DeclTemplate.cpp @@ -1653,57 +1653,65 @@ void TemplateParamObjectDecl::printAsInit(llvm::raw_ostream &OS, getValue().printPretty(OS, Policy, getType(), &getASTContext()); } -TemplateParameterList *clang::getReplacedTemplateParameterList(const Decl *D) { +std::tuple +clang::getReplacedTemplateParameter(Decl *D, unsigned Index) { switch (D->getKind()) { - case Decl::Kind::CXXRecord: - return cast(D) - ->getDescribedTemplate() - ->getTemplateParameters(); + case Decl::Kind::BuiltinTemplate: case Decl::Kind::ClassTemplate: - return cast(D)->getTemplateParameters(); + case Decl::Kind::Concept: + case Decl::Kind::FunctionTemplate: + case Decl::Kind::TemplateTemplateParm: + case Decl::Kind::TypeAliasTemplate: + case Decl::Kind::VarTemplate: + return {cast(D)->getTemplateParameters()->getParam(Index), + {}}; case Decl::Kind::ClassTemplateSpecialization: { const auto *CTSD = cast(D); auto P = CTSD->getSpecializedTemplateOrPartial(); + TemplateParameterList *TPL; if (const auto *CTPSD = dyn_cast(P)) - return CTPSD->getTemplateParameters(); - return cast(P)->getTemplateParameters(); + TPL = CTPSD->getTemplateParameters(); + else + TPL = cast(P)->getTemplateParameters(); + return {TPL->getParam(Index), CTSD->getTemplateArgs()[Index]}; + } + case Decl::Kind::VarTemplateSpecialization: { + const auto *VTSD = cast(D); + auto P = VTSD->getSpecializedTemplateOrPartial(); + TemplateParameterList *TPL; + if (const auto *VTPSD = dyn_cast(P)) + TPL = VTPSD->getTemplateParameters(); + else + TPL = cast(P)->getTemplateParameters(); + return {TPL->getParam(Index), VTSD->getTemplateArgs()[Index]}; } case Decl::Kind::ClassTemplatePartialSpecialization: - return cast(D) - ->getTemplateParameters(); - case Decl::Kind::TypeAliasTemplate: - return cast(D)->getTemplateParameters(); - case Decl::Kind::BuiltinTemplate: - return cast(D)->getTemplateParameters(); + return {cast(D) + ->getTemplateParameters() + ->getParam(Index), + {}}; + case Decl::Kind::VarTemplatePartialSpecialization: + return {cast(D) + ->getTemplateParameters() + ->getParam(Index), + {}}; + // This is used as the AssociatedDecl for placeholder type deduction. + case Decl::TemplateTypeParm: + return {cast(D), {}}; + // FIXME: Always use the template decl as the AssociatedDecl. + case Decl::Kind::CXXRecord: + return getReplacedTemplateParameter( + cast(D)->getDescribedClassTemplate(), Index); case Decl::Kind::CXXDeductionGuide: case Decl::Kind::CXXConversion: case Decl::Kind::CXXConstructor: case Decl::Kind::CXXDestructor: case Decl::Kind::CXXMethod: case Decl::Kind::Function: - return cast(D) - ->getTemplateSpecializationInfo() - ->getTemplate() - ->getTemplateParameters(); - case Decl::Kind::FunctionTemplate: - return cast(D)->getTemplateParameters(); - case Decl::Kind::VarTemplate: - return cast(D)->getTemplateParameters(); - case Decl::Kind::VarTemplateSpecialization: { - const auto *VTSD = cast(D); - auto P = VTSD->getSpecializedTemplateOrPartial(); - if (const auto *VTPSD = dyn_cast(P)) - return VTPSD->getTemplateParameters(); - return cast(P)->getTemplateParameters(); - } - case Decl::Kind::VarTemplatePartialSpecialization: - return cast(D) - ->getTemplateParameters(); - case Decl::Kind::TemplateTemplateParm: - return cast(D)->getTemplateParameters(); - case Decl::Kind::Concept: - return cast(D)->getTemplateParameters(); + return getReplacedTemplateParameter( + cast(D)->getTemplateSpecializationInfo()->getTemplate(), + Index); default: llvm_unreachable("Unhandled templated declaration kind"); } diff --git a/clang/lib/AST/ExprCXX.cpp b/clang/lib/AST/ExprCXX.cpp index 95de6a82a5270..c7f0ff040194d 100644 --- a/clang/lib/AST/ExprCXX.cpp +++ b/clang/lib/AST/ExprCXX.cpp @@ -1727,7 +1727,7 @@ SizeOfPackExpr *SizeOfPackExpr::CreateDeserialized(ASTContext &Context, NonTypeTemplateParmDecl *SubstNonTypeTemplateParmExpr::getParameter() const { return cast( - getReplacedTemplateParameterList(getAssociatedDecl())->asArray()[Index]); + std::get<0>(getReplacedTemplateParameter(getAssociatedDecl(), Index))); } PackIndexingExpr *PackIndexingExpr::Create( @@ -1793,7 +1793,7 @@ SubstNonTypeTemplateParmPackExpr::SubstNonTypeTemplateParmPackExpr( NonTypeTemplateParmDecl * SubstNonTypeTemplateParmPackExpr::getParameterPack() const { return cast( - getReplacedTemplateParameterList(getAssociatedDecl())->asArray()[Index]); + std::get<0>(getReplacedTemplateParameter(getAssociatedDecl(), Index))); } TemplateArgument SubstNonTypeTemplateParmPackExpr::getArgumentPack() const { diff --git a/clang/lib/AST/TemplateName.cpp b/clang/lib/AST/TemplateName.cpp index 2b8044e4188cd..797a354c5d0fa 100644 --- a/clang/lib/AST/TemplateName.cpp +++ b/clang/lib/AST/TemplateName.cpp @@ -64,16 +64,14 @@ SubstTemplateTemplateParmPackStorage::getArgumentPack() const { TemplateTemplateParmDecl * SubstTemplateTemplateParmPackStorage::getParameterPack() const { - return cast( - getReplacedTemplateParameterList(getAssociatedDecl()) - ->asArray()[Bits.Index]); + return cast(std::get<0>( + getReplacedTemplateParameter(getAssociatedDecl(), Bits.Index))); } TemplateTemplateParmDecl * SubstTemplateTemplateParmStorage::getParameter() const { - return cast( - getReplacedTemplateParameterList(getAssociatedDecl()) - ->asArray()[Bits.Index]); + return cast(std::get<0>( + getReplacedTemplateParameter(getAssociatedDecl(), Bits.Index))); } void SubstTemplateTemplateParmStorage::Profile(llvm::FoldingSetNodeID &ID) { diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp index 9794314a98f81..ee7a68ee8ba7e 100644 --- a/clang/lib/AST/Type.cpp +++ b/clang/lib/AST/Type.cpp @@ -4436,14 +4436,6 @@ IdentifierInfo *TemplateTypeParmType::getIdentifier() const { return isCanonicalUnqualified() ? nullptr : getDecl()->getIdentifier(); } -static const TemplateTypeParmDecl *getReplacedParameter(Decl *D, - unsigned Index) { - if (const auto *TTP = dyn_cast(D)) - return TTP; - return cast( - getReplacedTemplateParameterList(D)->getParam(Index)); -} - SubstTemplateTypeParmType::SubstTemplateTypeParmType(QualType Replacement, Decl *AssociatedDecl, unsigned Index, @@ -4466,7 +4458,8 @@ SubstTemplateTypeParmType::SubstTemplateTypeParmType(QualType Replacement, const TemplateTypeParmDecl * SubstTemplateTypeParmType::getReplacedParameter() const { - return ::getReplacedParameter(getAssociatedDecl(), getIndex()); + return cast(std::get<0>( + getReplacedTemplateParameter(getAssociatedDecl(), getIndex()))); } void SubstTemplateTypeParmType::Profile(llvm::FoldingSetNodeID &ID, @@ -4532,7 +4525,8 @@ bool SubstTemplateTypeParmPackType::getFinal() const { const TemplateTypeParmDecl * SubstTemplateTypeParmPackType::getReplacedParameter() const { - return ::getReplacedParameter(getAssociatedDecl(), getIndex()); + return cast(std::get<0>( + getReplacedTemplateParameter(getAssociatedDecl(), getIndex()))); } IdentifierInfo *SubstTemplateTypeParmPackType::getIdentifier() const { diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h index 51b55b82f4208..940324bbc5e4d 100644 --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -16364,16 +16364,21 @@ ExprResult TreeTransform::TransformSubstNonTypeTemplateParmExpr( AssociatedDecl == E->getAssociatedDecl()) return E; + auto getParamAndType = [Index = E->getIndex()](Decl *AssociatedDecl) + -> std::tuple { + auto [PDecl, Arg] = getReplacedTemplateParameter(AssociatedDecl, Index); + auto *Param = cast(PDecl); + return {Param, Arg.isNull() ? Param->getType() + : Arg.getNonTypeTemplateArgumentType()}; + }; + // If the replacement expression did not change, and the parameter type // did not change, we can skip the semantic action because it would // produce the same result anyway. - auto *Param = cast( - getReplacedTemplateParameterList(AssociatedDecl) - ->asArray()[E->getIndex()]); - if (QualType ParamType = Param->getType(); - !SemaRef.Context.hasSameType(ParamType, E->getParameter()->getType()) || + if (auto [Param, ParamType] = getParamAndType(AssociatedDecl); + !SemaRef.Context.hasSameType( + ParamType, std::get<1>(getParamAndType(E->getAssociatedDecl()))) || Replacement.get() != OrigReplacement) { - // When transforming the replacement expression previously, all Sema // specific annotations, such as implicit casts, are discarded. Calling the // corresponding sema action is necessary to recover those. Otherwise, diff --git a/clang/test/CodeGenCXX/template-cxx20.cpp b/clang/test/CodeGenCXX/template-cxx20.cpp new file mode 100644 index 0000000000000..aeb1cc915145f --- /dev/null +++ b/clang/test/CodeGenCXX/template-cxx20.cpp @@ -0,0 +1,24 @@ +// RUN: %clang_cc1 %s -O0 -disable-llvm-passes -triple=x86_64 -std=c++20 -emit-llvm -o - | FileCheck %s + +namespace GH161029_regression1 { + template auto f(int) { _Fp{}(0); } + template void g() { + (..., f<_Fp>(_Js)); + } + enum E { k }; + template struct ElementAt; + template struct ElementAt<0, First> { + static int value; + }; + template struct TagSet { + template using Tag = ElementAt; + }; + template struct S { + void U() { (void)TagSet::template Tag<0>::value; } + }; + S> s; + void h() { + g void { s.U(); }), 0>(); + } + // CHECK: call void @_ZN20GH161029_regression11SINS_6TagSetINS_1EELS2_0EEEE1UEv +}