Skip to content

Commit

Permalink
[clang] CTAD: build aggregate deduction guides for alias templates. (#…
Browse files Browse the repository at this point in the history
…85904)

Fixes #85767.

The aggregate deduction guides are handled in a separate code path. We
don't generate dedicated aggregate deduction guides for alias templates
(we just reuse the ones from the underlying template decl by accident).
The patch fixes this incorrect issue.

Note: there is a small refactoring change in this PR, where we move the
cache logic from `Sema::DeduceTemplateSpecializationFromInitializer` to
`Sema::DeclareImplicitDeductionGuideFromInitList`
  • Loading branch information
hokein committed Apr 5, 2024
1 parent 43ba568 commit 5e77dfe
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 70 deletions.
3 changes: 2 additions & 1 deletion clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -9713,7 +9713,8 @@ class Sema final {
/// not already done so.
void DeclareImplicitDeductionGuides(TemplateDecl *Template,
SourceLocation Loc);
FunctionTemplateDecl *DeclareImplicitDeductionGuideFromInitList(

FunctionTemplateDecl *DeclareAggregateDeductionGuideFromInitList(
TemplateDecl *Template, MutableArrayRef<QualType> ParamTypes,
SourceLocation Loc);

Expand Down
28 changes: 6 additions & 22 deletions clang/lib/Sema/SemaInit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10930,32 +10930,16 @@ QualType Sema::DeduceTemplateSpecializationFromInitializer(
Context.getLValueReferenceType(ElementTypes[I].withConst());
}

llvm::FoldingSetNodeID ID;
ID.AddPointer(Template);
for (auto &T : ElementTypes)
T.getCanonicalType().Profile(ID);
unsigned Hash = ID.ComputeHash();
if (AggregateDeductionCandidates.count(Hash) == 0) {
if (FunctionTemplateDecl *TD =
DeclareImplicitDeductionGuideFromInitList(
Template, ElementTypes,
TSInfo->getTypeLoc().getEndLoc())) {
auto *GD = cast<CXXDeductionGuideDecl>(TD->getTemplatedDecl());
GD->setDeductionCandidateKind(DeductionCandidate::Aggregate);
AggregateDeductionCandidates[Hash] = GD;
addDeductionCandidate(TD, GD, DeclAccessPair::make(TD, AS_public),
OnlyListConstructors,
/*AllowAggregateDeductionCandidate=*/true);
}
} else {
CXXDeductionGuideDecl *GD = AggregateDeductionCandidates[Hash];
FunctionTemplateDecl *TD = GD->getDescribedFunctionTemplate();
assert(TD && "aggregate deduction candidate is function template");
if (FunctionTemplateDecl *TD =
DeclareAggregateDeductionGuideFromInitList(
LookupTemplateDecl, ElementTypes,
TSInfo->getTypeLoc().getEndLoc())) {
auto *GD = cast<CXXDeductionGuideDecl>(TD->getTemplatedDecl());
addDeductionCandidate(TD, GD, DeclAccessPair::make(TD, AS_public),
OnlyListConstructors,
/*AllowAggregateDeductionCandidate=*/true);
HasAnyDeductionGuide = true;
}
HasAnyDeductionGuide = true;
}
};

Expand Down
209 changes: 162 additions & 47 deletions clang/lib/Sema/SemaTemplate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2754,23 +2754,42 @@ bool hasDeclaredDeductionGuides(DeclarationName Name, DeclContext *DC) {
return false;
}

// Build deduction guides for a type alias template.
void DeclareImplicitDeductionGuidesForTypeAlias(
Sema &SemaRef, TypeAliasTemplateDecl *AliasTemplate, SourceLocation Loc) {
if (AliasTemplate->isInvalidDecl())
return;
auto &Context = SemaRef.Context;
// FIXME: if there is an explicit deduction guide after the first use of the
// type alias usage, we will not cover this explicit deduction guide. fix this
// case.
if (hasDeclaredDeductionGuides(
Context.DeclarationNames.getCXXDeductionGuideName(AliasTemplate),
AliasTemplate->getDeclContext()))
return;
NamedDecl *transformTemplateParameter(Sema &SemaRef, DeclContext *DC,
NamedDecl *TemplateParam,
MultiLevelTemplateArgumentList &Args,
unsigned NewIndex) {
if (auto *TTP = dyn_cast<TemplateTypeParmDecl>(TemplateParam))
return transformTemplateTypeParam(SemaRef, DC, TTP, Args, TTP->getDepth(),
NewIndex);
if (auto *TTP = dyn_cast<TemplateTemplateParmDecl>(TemplateParam))
return transformTemplateParam(SemaRef, DC, TTP, Args, NewIndex,
TTP->getDepth());
if (auto *NTTP = dyn_cast<NonTypeTemplateParmDecl>(TemplateParam))
return transformTemplateParam(SemaRef, DC, NTTP, Args, NewIndex,
NTTP->getDepth());
llvm_unreachable("Unhandled template parameter types");
}

Expr *transformRequireClause(Sema &SemaRef, FunctionTemplateDecl *FTD,
llvm::ArrayRef<TemplateArgument> TransformedArgs) {
Expr *RC = FTD->getTemplateParameters()->getRequiresClause();
if (!RC)
return nullptr;
MultiLevelTemplateArgumentList Args;
Args.setKind(TemplateSubstitutionKind::Rewrite);
Args.addOuterTemplateArguments(TransformedArgs);
ExprResult E = SemaRef.SubstExpr(RC, Args);
if (E.isInvalid())
return nullptr;
return E.getAs<Expr>();
}

std::pair<TemplateDecl *, llvm::ArrayRef<TemplateArgument>>
getRHSTemplateDeclAndArgs(Sema &SemaRef, TypeAliasTemplateDecl *AliasTemplate) {
// Unwrap the sugared ElaboratedType.
auto RhsType = AliasTemplate->getTemplatedDecl()
->getUnderlyingType()
.getSingleStepDesugaredType(Context);
.getSingleStepDesugaredType(SemaRef.Context);
TemplateDecl *Template = nullptr;
llvm::ArrayRef<TemplateArgument> AliasRhsTemplateArgs;
if (const auto *TST = RhsType->getAs<TemplateSpecializationType>()) {
Expand All @@ -2791,6 +2810,24 @@ void DeclareImplicitDeductionGuidesForTypeAlias(
} else {
assert(false && "unhandled RHS type of the alias");
}
return {Template, AliasRhsTemplateArgs};
}

// Build deduction guides for a type alias template.
void DeclareImplicitDeductionGuidesForTypeAlias(
Sema &SemaRef, TypeAliasTemplateDecl *AliasTemplate, SourceLocation Loc) {
if (AliasTemplate->isInvalidDecl())
return;
auto &Context = SemaRef.Context;
// FIXME: if there is an explicit deduction guide after the first use of the
// type alias usage, we will not cover this explicit deduction guide. fix this
// case.
if (hasDeclaredDeductionGuides(
Context.DeclarationNames.getCXXDeductionGuideName(AliasTemplate),
AliasTemplate->getDeclContext()))
return;
auto [Template, AliasRhsTemplateArgs] =
getRHSTemplateDeclAndArgs(SemaRef, AliasTemplate);
if (!Template)
return;
DeclarationNameInfo NameInfo(
Expand All @@ -2803,6 +2840,13 @@ void DeclareImplicitDeductionGuidesForTypeAlias(
FunctionTemplateDecl *F = dyn_cast<FunctionTemplateDecl>(G);
if (!F)
continue;
// The **aggregate** deduction guides are handled in a different code path
// (DeclareImplicitDeductionGuideFromInitList), which involves the tricky
// cache.
if (cast<CXXDeductionGuideDecl>(F->getTemplatedDecl())
->getDeductionCandidateKind() == DeductionCandidate::Aggregate)
continue;

auto RType = F->getTemplatedDecl()->getReturnType();
// The (trailing) return type of the deduction guide.
const TemplateSpecializationType *FReturnType =
Expand Down Expand Up @@ -2885,21 +2929,6 @@ void DeclareImplicitDeductionGuidesForTypeAlias(
// parameters, used for building `TemplateArgsForBuildingFPrime`.
SmallVector<TemplateArgument, 16> TransformedDeducedAliasArgs(
AliasTemplate->getTemplateParameters()->size());
auto TransformTemplateParameter =
[&SemaRef](DeclContext *DC, NamedDecl *TemplateParam,
MultiLevelTemplateArgumentList &Args,
unsigned NewIndex) -> NamedDecl * {
if (auto *TTP = dyn_cast<TemplateTypeParmDecl>(TemplateParam))
return transformTemplateTypeParam(SemaRef, DC, TTP, Args,
TTP->getDepth(), NewIndex);
if (auto *TTP = dyn_cast<TemplateTemplateParmDecl>(TemplateParam))
return transformTemplateParam(SemaRef, DC, TTP, Args, NewIndex,
TTP->getDepth());
if (auto *NTTP = dyn_cast<NonTypeTemplateParmDecl>(TemplateParam))
return transformTemplateParam(SemaRef, DC, NTTP, Args, NewIndex,
NTTP->getDepth());
return nullptr;
};

for (unsigned AliasTemplateParamIdx : DeducedAliasTemplateParams) {
auto *TP = AliasTemplate->getTemplateParameters()->getParam(
Expand All @@ -2909,9 +2938,9 @@ void DeclareImplicitDeductionGuidesForTypeAlias(
MultiLevelTemplateArgumentList Args;
Args.setKind(TemplateSubstitutionKind::Rewrite);
Args.addOuterTemplateArguments(TransformedDeducedAliasArgs);
NamedDecl *NewParam =
TransformTemplateParameter(AliasTemplate->getDeclContext(), TP, Args,
/*NewIndex*/ FPrimeTemplateParams.size());
NamedDecl *NewParam = transformTemplateParameter(
SemaRef, AliasTemplate->getDeclContext(), TP, Args,
/*NewIndex*/ FPrimeTemplateParams.size());
FPrimeTemplateParams.push_back(NewParam);

auto NewTemplateArgument = Context.getCanonicalTemplateArgument(
Expand All @@ -2927,8 +2956,8 @@ void DeclareImplicitDeductionGuidesForTypeAlias(
// We take a shortcut here, it is ok to reuse the
// TemplateArgsForBuildingFPrime.
Args.addOuterTemplateArguments(TemplateArgsForBuildingFPrime);
NamedDecl *NewParam = TransformTemplateParameter(
F->getDeclContext(), TP, Args, FPrimeTemplateParams.size());
NamedDecl *NewParam = transformTemplateParameter(
SemaRef, F->getDeclContext(), TP, Args, FPrimeTemplateParams.size());
FPrimeTemplateParams.push_back(NewParam);

assert(TemplateArgsForBuildingFPrime[FTemplateParamIdx].isNull() &&
Expand All @@ -2938,16 +2967,8 @@ void DeclareImplicitDeductionGuidesForTypeAlias(
Context.getInjectedTemplateArg(NewParam));
}
// Substitute new template parameters into requires-clause if present.
Expr *RequiresClause = nullptr;
if (Expr *InnerRC = F->getTemplateParameters()->getRequiresClause()) {
MultiLevelTemplateArgumentList Args;
Args.setKind(TemplateSubstitutionKind::Rewrite);
Args.addOuterTemplateArguments(TemplateArgsForBuildingFPrime);
ExprResult E = SemaRef.SubstExpr(InnerRC, Args);
if (E.isInvalid())
return;
RequiresClause = E.getAs<Expr>();
}
Expr *RequiresClause =
transformRequireClause(SemaRef, F, TemplateArgsForBuildingFPrime);
// FIXME: implement the is_deducible constraint per C++
// [over.match.class.deduct]p3.3:
// ... and a constraint that is satisfied if and only if the arguments
Expand Down Expand Up @@ -3013,11 +3034,102 @@ void DeclareImplicitDeductionGuidesForTypeAlias(
}
}

// Build an aggregate deduction guide for a type alias template.
FunctionTemplateDecl *DeclareAggregateDeductionGuideForTypeAlias(
Sema &SemaRef, TypeAliasTemplateDecl *AliasTemplate,
MutableArrayRef<QualType> ParamTypes, SourceLocation Loc) {
TemplateDecl *RHSTemplate =
getRHSTemplateDeclAndArgs(SemaRef, AliasTemplate).first;
if (!RHSTemplate)
return nullptr;
auto *RHSDeductionGuide = SemaRef.DeclareAggregateDeductionGuideFromInitList(
RHSTemplate, ParamTypes, Loc);
if (!RHSDeductionGuide)
return nullptr;

LocalInstantiationScope Scope(SemaRef);

// Build a new template parameter list for the synthesized aggregate deduction
// guide by transforming the one from RHSDeductionGuide.
SmallVector<NamedDecl *> TransformedTemplateParams;
// Template args that refer to the rebuilt template parameters.
// All template arguments must be initialized in advance.
SmallVector<TemplateArgument> TransformedTemplateArgs(
RHSDeductionGuide->getTemplateParameters()->size());
for (auto *TP : *RHSDeductionGuide->getTemplateParameters()) {
// Rebuild any internal references to earlier parameters and reindex as
// we go.
MultiLevelTemplateArgumentList Args;
Args.setKind(TemplateSubstitutionKind::Rewrite);
Args.addOuterTemplateArguments(TransformedTemplateArgs);
NamedDecl *NewParam = transformTemplateParameter(
SemaRef, AliasTemplate->getDeclContext(), TP, Args,
/*NewIndex=*/TransformedTemplateParams.size());

TransformedTemplateArgs[TransformedTemplateParams.size()] =
SemaRef.Context.getCanonicalTemplateArgument(
SemaRef.Context.getInjectedTemplateArg(NewParam));
TransformedTemplateParams.push_back(NewParam);
}
// FIXME: implement the is_deducible constraint per C++
// [over.match.class.deduct]p3.3.
Expr *TransformedRequiresClause = transformRequireClause(
SemaRef, RHSDeductionGuide, TransformedTemplateArgs);
auto *TransformedTemplateParameterList = TemplateParameterList::Create(
SemaRef.Context, AliasTemplate->getTemplateParameters()->getTemplateLoc(),
AliasTemplate->getTemplateParameters()->getLAngleLoc(),
TransformedTemplateParams,
AliasTemplate->getTemplateParameters()->getRAngleLoc(),
TransformedRequiresClause);
auto *TransformedTemplateArgList = TemplateArgumentList::CreateCopy(
SemaRef.Context, TransformedTemplateArgs);

if (auto *TransformedDeductionGuide = SemaRef.InstantiateFunctionDeclaration(
RHSDeductionGuide, TransformedTemplateArgList,
AliasTemplate->getLocation(),
Sema::CodeSynthesisContext::BuildingDeductionGuides)) {
auto *GD =
llvm::dyn_cast<clang::CXXDeductionGuideDecl>(TransformedDeductionGuide);
FunctionTemplateDecl *Result = buildDeductionGuide(
SemaRef, AliasTemplate, TransformedTemplateParameterList,
GD->getCorrespondingConstructor(), GD->getExplicitSpecifier(),
GD->getTypeSourceInfo(), AliasTemplate->getBeginLoc(),
AliasTemplate->getLocation(), AliasTemplate->getEndLoc(),
GD->isImplicit());
cast<CXXDeductionGuideDecl>(Result->getTemplatedDecl())
->setDeductionCandidateKind(DeductionCandidate::Aggregate);
return Result;
}
return nullptr;
}

} // namespace

FunctionTemplateDecl *Sema::DeclareImplicitDeductionGuideFromInitList(
FunctionTemplateDecl *Sema::DeclareAggregateDeductionGuideFromInitList(
TemplateDecl *Template, MutableArrayRef<QualType> ParamTypes,
SourceLocation Loc) {
llvm::FoldingSetNodeID ID;
ID.AddPointer(Template);
for (auto &T : ParamTypes)
T.getCanonicalType().Profile(ID);
unsigned Hash = ID.ComputeHash();

auto Found = AggregateDeductionCandidates.find(Hash);
if (Found != AggregateDeductionCandidates.end()) {
CXXDeductionGuideDecl *GD = Found->getSecond();
return GD->getDescribedFunctionTemplate();
}

if (auto *AliasTemplate = llvm::dyn_cast<TypeAliasTemplateDecl>(Template)) {
if (auto *FTD = DeclareAggregateDeductionGuideForTypeAlias(
*this, AliasTemplate, ParamTypes, Loc)) {
auto *GD = cast<CXXDeductionGuideDecl>(FTD->getTemplatedDecl());
GD->setDeductionCandidateKind(DeductionCandidate::Aggregate);
AggregateDeductionCandidates[Hash] = GD;
return FTD;
}
}

if (CXXRecordDecl *DefRecord =
cast<CXXRecordDecl>(Template->getTemplatedDecl())->getDefinition()) {
if (TemplateDecl *DescribedTemplate =
Expand Down Expand Up @@ -3050,10 +3162,13 @@ FunctionTemplateDecl *Sema::DeclareImplicitDeductionGuideFromInitList(
Transform.NestedPattern ? Transform.NestedPattern : Transform.Template;
ContextRAII SavedContext(*this, Pattern->getTemplatedDecl());

auto *DG = cast<FunctionTemplateDecl>(
auto *FTD = cast<FunctionTemplateDecl>(
Transform.buildSimpleDeductionGuide(ParamTypes));
SavedContext.pop();
return DG;
auto *GD = cast<CXXDeductionGuideDecl>(FTD->getTemplatedDecl());
GD->setDeductionCandidateKind(DeductionCandidate::Aggregate);
AggregateDeductionCandidates[Hash] = GD;
return FTD;
}

void Sema::DeclareImplicitDeductionGuides(TemplateDecl *Template,
Expand Down
12 changes: 12 additions & 0 deletions clang/test/SemaTemplate/deduction-guide.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,15 @@ G g = {1};
// CHECK: FunctionTemplateDecl
// CHECK: |-CXXDeductionGuideDecl {{.*}} implicit <deduction guide for G> 'auto (T) -> G<T>' aggregate
// CHECK: `-CXXDeductionGuideDecl {{.*}} implicit used <deduction guide for G> 'auto (int) -> G<int>' implicit_instantiation aggregate

template<typename X>
using AG = G<X>;
AG ag = {1};
// Verify that the aggregate deduction guide for alias templates is built.
// CHECK-LABEL: Dumping <deduction guide for AG>
// CHECK: FunctionTemplateDecl
// CHECK: |-CXXDeductionGuideDecl {{.*}} 'auto (type-parameter-0-0) -> G<type-parameter-0-0>'
// CHECK: `-CXXDeductionGuideDecl {{.*}} 'auto (int) -> G<int>' implicit_instantiation
// CHECK: |-TemplateArgument type 'int'
// CHECK: | `-BuiltinType {{.*}} 'int'
// CHECK: `-ParmVarDecl {{.*}} 'int'

0 comments on commit 5e77dfe

Please sign in to comment.