From 549ca8fc684b2e0efcbde983b1d1b8be78bd6d80 Mon Sep 17 00:00:00 2001 From: "Larsen, Steffen" Date: Thu, 11 Dec 2025 05:43:04 -0800 Subject: [PATCH 1/6] [SYCL][clang] Fix more free-function kernel integration header cases This commit fixes a number of known issues when the integration header generates prototypes of the free-function kernels. These changes focus on the additional removal of aliasing and proper handling of templated template arguments. This commit also adds disabled test cases for a known issue with unresolved nested templated type aliases. These are cases for future work. Signed-off-by: Larsen, Steffen Co-authored-by: Sachkov, Alexey --- clang/include/clang/AST/TypeBase.h | 3 + clang/lib/AST/Type.cpp | 35 +++ clang/lib/Sema/SemaSYCL.cpp | 279 ++++++++++++++---- .../free-function-kernel-type-alias-arg.cpp | 80 ++++- 4 files changed, 335 insertions(+), 62 deletions(-) diff --git a/clang/include/clang/AST/TypeBase.h b/clang/include/clang/AST/TypeBase.h index c6bf312268647..899b9939b33d2 100644 --- a/clang/include/clang/AST/TypeBase.h +++ b/clang/include/clang/AST/TypeBase.h @@ -2956,6 +2956,9 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase { return TST; } + const TemplateSpecializationType * + getAsTemplateSpecializationTypeWithoutAliases(const ASTContext &Ctx) const; + /// Member-template getAsAdjusted. Look through specific kinds /// of sugar (parens, attributes, etc) for an instance of \. /// This is used when you need to walk over sugar nodes that represent some diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp index 8a0b11bfcedac..76c2b170835f2 100644 --- a/clang/lib/AST/Type.cpp +++ b/clang/lib/AST/Type.cpp @@ -1925,6 +1925,41 @@ Type::getAsNonAliasTemplateSpecializationType() const { return TST; } +const TemplateSpecializationType * +Type::getAsTemplateSpecializationTypeWithoutAliases( + const ASTContext &Ctx) const { + const TemplateSpecializationType *TST = + getAsNonAliasTemplateSpecializationType(); + if (!TST) + return TST; + + // Ensure the template arguments of the template specialization type are + // without aliases. + SmallVector ArgsWithoutAliases; + ArgsWithoutAliases.reserve(TST->template_arguments().size()); + for (const TemplateArgument &TA : TST->template_arguments()) { + if (TA.getKind() == TemplateArgument::ArgKind::Type) { + QualType TAQTy = TA.getAsType(); + const Type *TATy = TAQTy->getUnqualifiedDesugaredType(); + if (isa(TATy)) + TATy = TATy->getAsTemplateSpecializationTypeWithoutAliases(Ctx); + ArgsWithoutAliases.emplace_back(QualType(TATy, TAQTy.getCVRQualifiers())); + } else if (TA.getKind() == TemplateArgument::ArgKind::Template) { + TemplateName TN = TA.getAsTemplate(); + while (std::optional DesugaredTN = + TN.desugar(/*IgnoreDeduced=*/false)) + TN = *DesugaredTN; + ArgsWithoutAliases.emplace_back(TN); + } else { + ArgsWithoutAliases.push_back(TA); + } + } + return Ctx + .getTemplateSpecializationType(TST->getKeyword(), TST->getTemplateName(), + ArgsWithoutAliases, {}, QualType{}) + ->getAs(); +} + NestedNameSpecifier Type::getPrefix() const { switch (getTypeClass()) { case Type::DependentName: diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index 639e3a2d16cc7..923cb2a7ccd6c 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -6626,6 +6626,195 @@ static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) { [](raw_ostream &, const NamespaceDecl *) {}, OS, DC); } +/// Dedicated visitor which helps with printing of kernel arguments in forward +/// declarations of free function kernels which are declared as function +/// templates. +/// +/// Based on: +/// \code +/// template +/// void foo(T1 a, int b, T2 c); +/// \endcode +/// +/// It prints into the output stream "T1, int, T2". +/// +/// The main complexity (which motivates addition of such visitor) comes from +/// the fact that there could be type aliases and default template arguments. +/// For example: +/// \code +/// template +/// void kernel(sycl::accessor); +/// template void kernel(sycl::accessor); +/// \endcode +/// sycl::accessor has many template arguments which have default values. If +/// we iterate over non-canonicalized argument type, we don't get those default +/// values and we don't get necessary namespace qualifiers for all the template +/// arguments. If we iterate over canonicalized argument type, then all +/// references to T will be replaced with something like type-argument-X-Y. +/// What this visitor does is it iterates over both in sync, picking the right +/// values from one or another. +/// +/// Moral of the story: drop integration header ASAP (but that is blocked +/// by support for 3rd-party host compilers, which is important). +class FreeFunctionTemplateKernelArgsPrinter + : public ConstTemplateArgumentVisitor> { + raw_ostream &O; + PrintingPolicy &Policy; + ASTContext &Context; + + using Base = + ConstTemplateArgumentVisitor>; + + void PrintTemplateDeclName(const TemplateDecl *TD, + ArrayRef SpecArgs) {} + +public: + FreeFunctionTemplateKernelArgsPrinter(raw_ostream &O, PrintingPolicy &Policy, + ASTContext &Context) + : O(O), Policy(Policy), Context(Context) {} + + void Visit(const TemplateSpecializationType *T, + const TemplateSpecializationType *CT) { + ArrayRef SpecArgs = T->template_arguments(); + ArrayRef DeclArgs = CT->template_arguments(); + + const TemplateDecl *TD = CT->getTemplateName().getAsTemplateDecl(); + if (!TD->getIdentifier()) + TD = T->getTemplateName().getAsTemplateDecl(); + TD->printQualifiedName(O); + + O << "<"; + for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()), + SE = SpecArgs.size(); + I < E; ++I) { + if (I != 0) + O << ", "; + // If we have a specialized argument, use it. Otherwise fallback to a + // default argument. + // We pass specialized arguments in case there are references to them + // from other types. + // FIXME: passing SpecArgs here is incorrect. It refers to template + // arguments of a single function argument, but DeclArgs contain + // references (in form of depth-index) to template arguments of the + // function itself which results in incorrect integration header being + // produced. + Base::Visit(I < SE ? SpecArgs[I] : DeclArgs[I], SpecArgs); + } + O << ">"; + } + + // Internal version of the function above that is used when template argument + // is a template by itself + void Visit(const TemplateSpecializationType *T, + ArrayRef SpecArgs) { + const TemplateDecl *TD = T->getTemplateName().getAsTemplateDecl(); + const auto *TTPD = dyn_cast(TD); + if (TTPD && !TTPD->getIdentifier()) + SpecArgs[TTPD->getIndex()].print(Policy, O, /* IncludeType = */ false); + else + TD->printQualifiedName(O); + O << "<"; + ArrayRef DeclArgs = T->template_arguments(); + for (size_t I = 0, E = DeclArgs.size(); I < E; ++I) { + if (I != 0) + O << ", "; + Base::Visit(DeclArgs[I], SpecArgs); + } + O << ">"; + } + + void VisitNullTemplateArgument(const TemplateArgument &, + ArrayRef) { + llvm_unreachable("If template argument has not been deduced, then we can't " + "forward-declare it, something went wrong"); + } + + void VisitTypeTemplateArgument(const TemplateArgument &Arg, + ArrayRef SpecArgs) { + // If we reference an existing template argument without a known identifier, + // print it instead. + const auto *TPT = dyn_cast(Arg.getAsType()); + if (TPT && !TPT->getIdentifier()) { + SpecArgs[TPT->getIndex()].print(Policy, O, /* IncludeType = */ false); + return; + } + + const auto *TST = dyn_cast(Arg.getAsType()); + if (TST && Arg.isInstantiationDependent()) { + // This is an instantiation dependent template specialization, meaning + // that some of its arguments reference template arguments of the free + // function kernel itself. + Visit(TST, SpecArgs); + return; + } + + Arg.print(Policy, O, /* IncludeType = */ false); + } + + void VisitDeclarationTemplateArgument(const TemplateArgument &, + ArrayRef) { + llvm_unreachable("Free function kernels cannot have non-type template " + "arguments which are pointers or references"); + } + + void VisitNullPtrTemplateArgument(const TemplateArgument &, + ArrayRef) { + llvm_unreachable("Free function kernels cannot have non-type template " + "arguments which are pointers or references"); + } + + void VisitIntegralTemplateArgument(const TemplateArgument &Arg, + ArrayRef) { + Arg.print(Policy, O, /* IncludeType = */ false); + } + + void VisitStructuralValueTemplateArgument(const TemplateArgument &Arg, + ArrayRef) { + Arg.print(Policy, O, /* IncludeType = */ false); + } + + void VisitTemplateTemplateArgument(const TemplateArgument &Arg, + ArrayRef) { + Arg.print(Policy, O, /* IncludeType = */ false); + } + + void VisitTemplateExpansionTemplateArgument(const TemplateArgument &Arg, + ArrayRef) { + // Likely does not work similar to the one above + Arg.print(Policy, O, /* IncludeType = */ false); + } + + void VisitExpressionTemplateArgument(const TemplateArgument &Arg, + ArrayRef) { + Expr *E = Arg.getAsExpr(); + assert(E && "Failed to get an Expr for an Expression template arg?"); + + if (Arg.isInstantiationDependent() || + E->getType().getTypePtr()->isScopedEnumeralType()) { + // Scoped enumerations can't be implicitly cast from integers, so + // we don't need to evaluate them. + // If expression is instantiation-dependent, then we can't evaluate it + // either, let's fallback to default printing mechanism. + Arg.print(Policy, O, /* IncludeType = */ false); + return; + } + + Expr::EvalResult Res; + [[maybe_unused]] bool Success = + Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context); + assert(Success && "invalid non-type template argument?"); + assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?"); + Res.Val.printPretty(O, Policy, Arg.getAsExpr()->getType(), &Context); + } + + void VisitPackTemplateArgument(const TemplateArgument &Arg, + ArrayRef) { + Arg.print(Policy, O, /* IncludeType = */ false); + } +}; + class FreeFunctionPrinter { raw_ostream &O; PrintingPolicy &Policy; @@ -6789,7 +6978,10 @@ class FreeFunctionPrinter { llvm::raw_svector_ostream ParmListOstream{ParamList}; Policy.SuppressTagKeyword = true; - for (ParmVarDecl *Param : Parameters) { + FreeFunctionTemplateKernelArgsPrinter Printer(ParmListOstream, Policy, + Context); + + for (const ParmVarDecl *Param : Parameters) { if (FirstParam) FirstParam = false; else @@ -6822,53 +7014,11 @@ class FreeFunctionPrinter { } const TemplateSpecializationType *TSTAsNonAlias = - TST->getAsNonAliasTemplateSpecializationType(); + TST->getAsTemplateSpecializationTypeWithoutAliases(Context); if (TSTAsNonAlias) TST = TSTAsNonAlias; - TemplateName CTN = CTST->getTemplateName(); - CTN.getAsTemplateDecl()->printQualifiedName(ParmListOstream); - ParmListOstream << "<"; - - ArrayRef SpecArgs = TST->template_arguments(); - ArrayRef DeclArgs = CTST->template_arguments(); - - auto TemplateArgPrinter = [&](const TemplateArgument &Arg) { - if (Arg.getKind() != TemplateArgument::ArgKind::Expression || - Arg.isInstantiationDependent()) { - Arg.print(Policy, ParmListOstream, /* IncludeType = */ false); - return; - } - - Expr *E = Arg.getAsExpr(); - assert(E && "Failed to get an Expr for an Expression template arg?"); - if (E->getType().getTypePtr()->isScopedEnumeralType()) { - // Scoped enumerations can't be implicitly cast from integers, so - // we don't need to evaluate them. - Arg.print(Policy, ParmListOstream, /* IncludeType = */ false); - return; - } - - Expr::EvalResult Res; - [[maybe_unused]] bool Success = - Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context); - assert(Success && "invalid non-type template argument?"); - assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?"); - Res.Val.printPretty(ParmListOstream, Policy, Arg.getAsExpr()->getType(), - &Context); - }; - - for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()), - SE = SpecArgs.size(); - I < E; ++I) { - if (I != 0) - ParmListOstream << ", "; - // If we have a specialized argument, use it. Otherwise fallback to a - // default argument. - TemplateArgPrinter(I < SE ? SpecArgs[I] : DeclArgs[I]); - } - - ParmListOstream << ">"; + Printer.Visit(TST, CTST); } return ParamList.str().str(); } @@ -6886,26 +7036,39 @@ class FreeFunctionPrinter { std::string getTemplateParameters(const clang::TemplateParameterList *TPL) { std::string TemplateParams{"template <"}; bool FirstParam{true}; - for (NamedDecl *Param : *TPL) { + for (const NamedDecl *Param : *TPL) { if (!FirstParam) TemplateParams += ", "; FirstParam = false; - if (const auto *TemplateParam = dyn_cast(Param)) { - TemplateParams += - TemplateParam->wasDeclaredWithTypename() ? "typename " : "class "; - if (TemplateParam->isParameterPack()) - TemplateParams += "... "; - TemplateParams += TemplateParam->getNameAsString(); - } else if (const auto *NonTypeParam = - dyn_cast(Param)) { - TemplateParams += NonTypeParam->getType().getAsString(); - TemplateParams += " "; - TemplateParams += NonTypeParam->getNameAsString(); - } + TemplateParams += getTemplateParameter(Param); } TemplateParams += "> "; return TemplateParams; } + + /// Helper method to get text representation of a template parameter. + /// \param Param The template parameter. + std::string getTemplateParameter(const NamedDecl *Param) { + auto GetTypenameOrClass = [](const auto *Param) { + return Param->wasDeclaredWithTypename() ? "typename " : "class "; + }; + if (const auto *TemplateParam = dyn_cast(Param)) { + std::string TemplateParamStr = GetTypenameOrClass(TemplateParam); + if (TemplateParam->isParameterPack()) + TemplateParamStr += "... "; + TemplateParamStr += TemplateParam->getNameAsString(); + return TemplateParamStr; + } else if (const auto *NonTypeParam = + dyn_cast(Param)) { + return NonTypeParam->getType().getAsString() + " " + + NonTypeParam->getNameAsString(); + } else if (const auto *TTParam = + dyn_cast(Param)) { + return getTemplateParameters(TTParam->getTemplateParameters()) + " " + + GetTypenameOrClass(TTParam) + TTParam->getNameAsString(); + } + return ""; + } }; void SYCLIntegrationHeader::emit(raw_ostream &O) { diff --git a/clang/test/CodeGenSYCL/free-function-kernel-type-alias-arg.cpp b/clang/test/CodeGenSYCL/free-function-kernel-type-alias-arg.cpp index 5d6ea216d7d38..d8717eaaab44f 100644 --- a/clang/test/CodeGenSYCL/free-function-kernel-type-alias-arg.cpp +++ b/clang/test/CodeGenSYCL/free-function-kernel-type-alias-arg.cpp @@ -14,6 +14,9 @@ typedef int IntTypedef; template struct Foo {}; +template +using FooUsing = Foo; + using FooIntUsing = Foo; typedef Foo FooIntTypedef; @@ -29,15 +32,35 @@ using BarUsing2 = Bar, T1>; template using BarUsingBarUsing2 = BarUsing2; +template +using BarUsingFooIntUsing = Bar; + +template +using BarUsingBarUsingFooIntUsing = BarUsingFooIntUsing; + class Baz { public: using type = BarUsing; }; +template