Skip to content

Commit

Permalink
[Clang][Sema] Improve support for explicit specializations of constra…
Browse files Browse the repository at this point in the history
…ined member functions & member function templates (llvm#88963)

Consider the following snippet from the discussion of CWG2847 on the core reflector:
```
template<typename T>
concept C = sizeof(T) <= sizeof(long);

template<typename T>
struct A 
{
    template<typename U>
    void f(U) requires C<U>; // #1, declares a function template 

    void g() requires C<T>; // rust-lang#2, declares a function

    template<>
    void f(char);  // rust-lang#3, an explicit specialization of a function template that declares a function
};

template<>
template<typename U>
void A<short>::f(U) requires C<U>; // rust-lang#4, an explicit specialization of a function template that declares a function template

template<>
template<>
void A<int>::f(int); // rust-lang#5, an explicit specialization of a function template that declares a function

template<>
void A<long>::g(); // rust-lang#6, an explicit specialization of a function that declares a function
```

A number of problems exist:
- Clang rejects `rust-lang#4` because the trailing _requires-clause_ has `U`
substituted with the wrong template parameter depth when
`Sema::AreConstraintExpressionsEqual` is called to determine whether it
matches the trailing _requires-clause_ of the implicitly instantiated
function template.
- Clang rejects `rust-lang#5` because the function template specialization
instantiated from `A<int>::f` has a trailing _requires-clause_, but `rust-lang#5`
does not (nor can it have one as it isn't a templated function).
- Clang rejects `rust-lang#6` for the same reasons it rejects `rust-lang#5`.

This patch resolves these issues by making the following changes:
- To fix `rust-lang#4`, `Sema::AreConstraintExpressionsEqual` is passed
`FunctionTemplateDecl`s when comparing the trailing _requires-clauses_
of `rust-lang#4` and the function template instantiated from `#1`.
- To fix `rust-lang#5` and `rust-lang#6`, the trailing _requires-clauses_ are not compared
for explicit specializations that declare functions.

In addition to these changes, `CheckMemberSpecialization` now considers
constraint satisfaction/constraint partial ordering when determining
which member function is specialized by an explicit specialization of a
member function for an implicit instantiation of a class template (we
previously would select the first function that has the same type as the
explicit specialization). With constraints taken under consideration, we
match EDG's behavior for these declarations.
  • Loading branch information
sdkrystian committed May 8, 2024
1 parent 83f3b1c commit 34ae226
Show file tree
Hide file tree
Showing 10 changed files with 248 additions and 68 deletions.
4 changes: 4 additions & 0 deletions clang/docs/ReleaseNotes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,10 @@ Bug Fixes to C++ Support
until the noexcept-specifier is instantiated.
- Fix a crash when an implicitly declared ``operator==`` function with a trailing requires-clause has its
constraints compared to that of another declaration.
- Fix a bug where explicit specializations of member functions/function templates would have substitution
performed incorrectly when checking constraints. Fixes (#GH90349).
- Clang now allows constrained member functions to be explicitly specialized for an implicit instantiation
of a class template.

Bug Fixes to AST Handling
^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
5 changes: 5 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -5437,6 +5437,11 @@ def note_function_template_spec_matched : Note<
def err_function_template_partial_spec : Error<
"function template partial specialization is not allowed">;

def err_function_member_spec_ambiguous : Error<
"ambiguous member function specialization %q0 of %q1">;
def note_function_member_spec_matched : Note<
"member function specialization matches %0">;

// C++ Template Instantiation
def err_template_recursion_depth_exceeded : Error<
"recursive template instantiation exceeded maximum depth of %0">,
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -9739,6 +9739,9 @@ class Sema final : public SemaBase {
const PartialDiagnostic &CandidateDiag,
bool Complain = true, QualType TargetType = QualType());

FunctionDecl *getMoreConstrainedFunction(FunctionDecl *FD1,
FunctionDecl *FD2);

///@}

//
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/Sema/SemaConcept.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,7 @@ static const Expr *SubstituteConstraintExpressionWithoutSatisfaction(
// this may happen while we're comparing two templates' constraint
// equivalence.
LocalInstantiationScope ScopeForParameters(S);
if (auto *FD = llvm::dyn_cast<FunctionDecl>(DeclInfo.getDecl()))
if (auto *FD = DeclInfo.getDecl()->getAsFunction())
for (auto *PVD : FD->parameters())
ScopeForParameters.InstantiatedLocal(PVD, PVD);

Expand Down
72 changes: 19 additions & 53 deletions clang/lib/Sema/SemaOverload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1303,6 +1303,8 @@ static bool IsOverloadOrOverrideImpl(Sema &SemaRef, FunctionDecl *New,
if (New->isMSVCRTEntryPoint())
return false;

NamedDecl *OldDecl = Old;
NamedDecl *NewDecl = New;
FunctionTemplateDecl *OldTemplate = Old->getDescribedFunctionTemplate();
FunctionTemplateDecl *NewTemplate = New->getDescribedFunctionTemplate();

Expand Down Expand Up @@ -1347,6 +1349,8 @@ static bool IsOverloadOrOverrideImpl(Sema &SemaRef, FunctionDecl *New,
// references to non-instantiated entities during constraint substitution.
// GH78101.
if (NewTemplate) {
OldDecl = OldTemplate;
NewDecl = NewTemplate;
// C++ [temp.over.link]p4:
// The signature of a function template consists of its function
// signature, its return type and its template parameter list. The names
Expand Down Expand Up @@ -1506,13 +1510,14 @@ static bool IsOverloadOrOverrideImpl(Sema &SemaRef, FunctionDecl *New,
}
}

if (!UseOverrideRules) {
if (!UseOverrideRules &&
New->getTemplateSpecializationKind() != TSK_ExplicitSpecialization) {
Expr *NewRC = New->getTrailingRequiresClause(),
*OldRC = Old->getTrailingRequiresClause();
if ((NewRC != nullptr) != (OldRC != nullptr))
return true;

if (NewRC && !SemaRef.AreConstraintExpressionsEqual(Old, OldRC, New, NewRC))
if (NewRC &&
!SemaRef.AreConstraintExpressionsEqual(OldDecl, OldRC, NewDecl, NewRC))
return true;
}

Expand Down Expand Up @@ -10695,29 +10700,10 @@ bool clang::isBetterOverloadCandidate(
// -— F1 and F2 are non-template functions with the same
// parameter-type-lists, and F1 is more constrained than F2 [...],
if (!Cand1IsSpecialization && !Cand2IsSpecialization &&
sameFunctionParameterTypeLists(S, Cand1, Cand2)) {
FunctionDecl *Function1 = Cand1.Function;
FunctionDecl *Function2 = Cand2.Function;
if (FunctionDecl *MF = Function1->getInstantiatedFromMemberFunction())
Function1 = MF;
if (FunctionDecl *MF = Function2->getInstantiatedFromMemberFunction())
Function2 = MF;

const Expr *RC1 = Function1->getTrailingRequiresClause();
const Expr *RC2 = Function2->getTrailingRequiresClause();
if (RC1 && RC2) {
bool AtLeastAsConstrained1, AtLeastAsConstrained2;
if (S.IsAtLeastAsConstrained(Function1, RC1, Function2, RC2,
AtLeastAsConstrained1) ||
S.IsAtLeastAsConstrained(Function2, RC2, Function1, RC1,
AtLeastAsConstrained2))
return false;
if (AtLeastAsConstrained1 != AtLeastAsConstrained2)
return AtLeastAsConstrained1;
} else if (RC1 || RC2) {
return RC1 != nullptr;
}
}
sameFunctionParameterTypeLists(S, Cand1, Cand2) &&
S.getMoreConstrainedFunction(Cand1.Function, Cand2.Function) ==
Cand1.Function)
return true;

// -- F1 is a constructor for a class D, F2 is a constructor for a base
// class B of D, and for all arguments the corresponding parameters of
Expand Down Expand Up @@ -13385,25 +13371,6 @@ Sema::resolveAddressOfSingleOverloadCandidate(Expr *E, DeclAccessPair &Pair) {
static_cast<int>(CUDA().IdentifyPreference(Caller, FD2));
};

auto CheckMoreConstrained = [&](FunctionDecl *FD1,
FunctionDecl *FD2) -> std::optional<bool> {
if (FunctionDecl *MF = FD1->getInstantiatedFromMemberFunction())
FD1 = MF;
if (FunctionDecl *MF = FD2->getInstantiatedFromMemberFunction())
FD2 = MF;
SmallVector<const Expr *, 1> AC1, AC2;
FD1->getAssociatedConstraints(AC1);
FD2->getAssociatedConstraints(AC2);
bool AtLeastAsConstrained1, AtLeastAsConstrained2;
if (IsAtLeastAsConstrained(FD1, AC1, FD2, AC2, AtLeastAsConstrained1))
return std::nullopt;
if (IsAtLeastAsConstrained(FD2, AC2, FD1, AC1, AtLeastAsConstrained2))
return std::nullopt;
if (AtLeastAsConstrained1 == AtLeastAsConstrained2)
return std::nullopt;
return AtLeastAsConstrained1;
};

// Don't use the AddressOfResolver because we're specifically looking for
// cases where we have one overload candidate that lacks
// enable_if/pass_object_size/...
Expand Down Expand Up @@ -13440,15 +13407,14 @@ Sema::resolveAddressOfSingleOverloadCandidate(Expr *E, DeclAccessPair &Pair) {
}
// FD has the same CUDA prefernece than Result. Continue check
// constraints.
std::optional<bool> MoreConstrainedThanPrevious =
CheckMoreConstrained(FD, Result);
if (!MoreConstrainedThanPrevious) {
IsResultAmbiguous = true;
AmbiguousDecls.push_back(FD);
FunctionDecl *MoreConstrained = getMoreConstrainedFunction(FD, Result);
if (MoreConstrained != FD) {
if (!MoreConstrained) {
IsResultAmbiguous = true;
AmbiguousDecls.push_back(FD);
}
continue;
}
if (!*MoreConstrainedThanPrevious)
continue;
// FD is more constrained - replace Result with it.
}
FoundBetter();
Expand All @@ -13467,7 +13433,7 @@ Sema::resolveAddressOfSingleOverloadCandidate(Expr *E, DeclAccessPair &Pair) {
// constraints.
if (getLangOpts().CUDA && CheckCUDAPreference(Skipped, Result) != 0)
continue;
if (!CheckMoreConstrained(Skipped, Result))
if (!getMoreConstrainedFunction(Skipped, Result))
return nullptr;
}
Pair = DAP;
Expand Down
57 changes: 43 additions & 14 deletions clang/lib/Sema/SemaTemplate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10339,24 +10339,53 @@ Sema::CheckMemberSpecialization(NamedDecl *Member, LookupResult &Previous) {
if (Previous.empty()) {
// Nowhere to look anyway.
} else if (FunctionDecl *Function = dyn_cast<FunctionDecl>(Member)) {
SmallVector<FunctionDecl *> Candidates;
bool Ambiguous = false;
for (LookupResult::iterator I = Previous.begin(), E = Previous.end();
I != E; ++I) {
NamedDecl *D = (*I)->getUnderlyingDecl();
if (CXXMethodDecl *Method = dyn_cast<CXXMethodDecl>(D)) {
QualType Adjusted = Function->getType();
if (!hasExplicitCallingConv(Adjusted))
Adjusted = adjustCCAndNoReturn(Adjusted, Method->getType());
// This doesn't handle deduced return types, but both function
// declarations should be undeduced at this point.
if (Context.hasSameType(Adjusted, Method->getType())) {
FoundInstantiation = *I;
Instantiation = Method;
InstantiatedFrom = Method->getInstantiatedFromMemberFunction();
MSInfo = Method->getMemberSpecializationInfo();
break;
}
CXXMethodDecl *Method =
dyn_cast<CXXMethodDecl>((*I)->getUnderlyingDecl());
if (!Method)
continue;
QualType Adjusted = Function->getType();
if (!hasExplicitCallingConv(Adjusted))
Adjusted = adjustCCAndNoReturn(Adjusted, Method->getType());
// This doesn't handle deduced return types, but both function
// declarations should be undeduced at this point.
if (!Context.hasSameType(Adjusted, Method->getType()))
continue;
if (ConstraintSatisfaction Satisfaction;
Method->getTrailingRequiresClause() &&
(CheckFunctionConstraints(Method, Satisfaction,
/*UsageLoc=*/Member->getLocation(),
/*ForOverloadResolution=*/true) ||
!Satisfaction.IsSatisfied))
continue;
Candidates.push_back(Method);
FunctionDecl *MoreConstrained =
Instantiation ? getMoreConstrainedFunction(
Method, cast<FunctionDecl>(Instantiation))
: Method;
if (!MoreConstrained) {
Ambiguous = true;
continue;
}
if (MoreConstrained == Method) {
Ambiguous = false;
FoundInstantiation = *I;
Instantiation = Method;
InstantiatedFrom = Method->getInstantiatedFromMemberFunction();
MSInfo = Method->getMemberSpecializationInfo();
}
}
if (Ambiguous) {
Diag(Member->getLocation(), diag::err_function_member_spec_ambiguous)
<< Member << (InstantiatedFrom ? InstantiatedFrom : Instantiation);
for (FunctionDecl *Candidate : Candidates)
Diag(Candidate->getLocation(), diag::note_function_member_spec_matched)
<< Candidate;
return true;
}
} else if (isa<VarDecl>(Member)) {
VarDecl *PrevVar;
if (Previous.isSingleResult() &&
Expand Down
32 changes: 32 additions & 0 deletions clang/lib/Sema/SemaTemplateDeduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5852,6 +5852,38 @@ UnresolvedSetIterator Sema::getMostSpecialized(
return SpecEnd;
}

/// Returns the more constrained function according to the rules of
/// partial ordering by constraints (C++ [temp.constr.order]).
///
/// \param FD1 the first function
///
/// \param FD2 the second function
///
/// \returns the more constrained function. If neither function is
/// more constrained, returns NULL.
FunctionDecl *Sema::getMoreConstrainedFunction(FunctionDecl *FD1,
FunctionDecl *FD2) {
assert(!FD1->getDescribedTemplate() && !FD2->getDescribedTemplate() &&
"not for function templates");
FunctionDecl *F1 = FD1;
if (FunctionDecl *MF = FD1->getInstantiatedFromMemberFunction())
F1 = MF;
FunctionDecl *F2 = FD2;
if (FunctionDecl *MF = FD2->getInstantiatedFromMemberFunction())
F2 = MF;
llvm::SmallVector<const Expr *, 1> AC1, AC2;
F1->getAssociatedConstraints(AC1);
F2->getAssociatedConstraints(AC2);
bool AtLeastAsConstrained1, AtLeastAsConstrained2;
if (IsAtLeastAsConstrained(F1, AC1, F2, AC2, AtLeastAsConstrained1))
return nullptr;
if (IsAtLeastAsConstrained(F2, AC2, F1, AC1, AtLeastAsConstrained2))
return nullptr;
if (AtLeastAsConstrained1 == AtLeastAsConstrained2)
return nullptr;
return AtLeastAsConstrained1 ? FD1 : FD2;
}

/// Determine whether one partial specialization, P1, is at least as
/// specialized than another, P2.
///
Expand Down
7 changes: 7 additions & 0 deletions clang/lib/Sema/SemaTemplateInstantiate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,13 @@ Response HandleFunction(Sema &SemaRef, const FunctionDecl *Function,
TemplateArgs->asArray(),
/*Final=*/false);

if (RelativeToPrimary &&
(Function->getTemplateSpecializationKind() ==
TSK_ExplicitSpecialization ||
(Function->getFriendObjectKind() &&
!Function->getPrimaryTemplate()->getFriendObjectKind())))
return Response::UseNextDecl(Function);

// If this function was instantiated from a specialized member that is
// a function template, we're done.
assert(Function->getPrimaryTemplate() && "No function template?");
Expand Down
60 changes: 60 additions & 0 deletions clang/test/CXX/temp/temp.spec/temp.expl.spec/p14-23.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// RUN: %clang_cc1 -std=c++20 -verify %s

template<int I>
concept C = I >= 4;

template<int I>
concept D = I < 8;

template<int I>
struct A {
constexpr static int f() { return 0; }
constexpr static int f() requires C<I> && D<I> { return 1; }
constexpr static int f() requires C<I> { return 2; }

constexpr static int g() requires C<I> { return 0; } // #candidate-0
constexpr static int g() requires D<I> { return 1; } // #candidate-1

constexpr static int h() requires C<I> { return 0; } // expected-note {{member declaration nearly matches}}
};

template<>
constexpr int A<2>::f() { return 3; }

template<>
constexpr int A<4>::f() { return 4; }

template<>
constexpr int A<8>::f() { return 5; }

static_assert(A<3>::f() == 0);
static_assert(A<5>::f() == 1);
static_assert(A<9>::f() == 2);
static_assert(A<2>::f() == 3);
static_assert(A<4>::f() == 4);
static_assert(A<8>::f() == 5);

template<>
constexpr int A<0>::g() { return 2; }

template<>
constexpr int A<8>::g() { return 3; }

template<>
constexpr int A<6>::g() { return 4; } // expected-error {{ambiguous member function specialization 'A<6>::g' of 'A::g'}}
// expected-note@#candidate-0 {{member function specialization matches 'g'}}
// expected-note@#candidate-1 {{member function specialization matches 'g'}}

static_assert(A<9>::g() == 0);
static_assert(A<1>::g() == 1);
static_assert(A<0>::g() == 2);
static_assert(A<8>::g() == 3);

template<>
constexpr int A<4>::h() { return 1; }

template<>
constexpr int A<0>::h() { return 2; } // expected-error {{out-of-line definition of 'h' does not match any declaration in 'A<0>'}}

static_assert(A<5>::h() == 0);
static_assert(A<4>::h() == 1);

0 comments on commit 34ae226

Please sign in to comment.