diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index b9a5d202af373..ffe1280c387f1 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -6626,10 +6626,12 @@ class FreeFunctionPrinter { raw_ostream &O; PrintingPolicy &Policy; bool NSInserted = false; + ASTContext &Context; public: - FreeFunctionPrinter(raw_ostream &O, PrintingPolicy &PrintPolicy) - : O(O), Policy(PrintPolicy) {} + FreeFunctionPrinter(raw_ostream &O, PrintingPolicy &PrintPolicy, + ASTContext &Context) + : O(O), Policy(PrintPolicy), Context(Context) {} /// Emits the function declaration of template free function. /// \param FTD The function declaration to print. @@ -6826,18 +6828,42 @@ class FreeFunctionPrinter { CTN.getAsTemplateDecl()->printQualifiedName(ParmListOstream); ParmListOstream << "<"; - auto SpecArgs = TST->template_arguments(); - auto DeclArgs = CTST->template_arguments(); + ArrayRef SpecArgs = TST->template_arguments(); + ArrayRef DeclArgs = CTST->template_arguments(); + + auto TemplateArgPrinter = [&](const TemplateArgument &Arg) { + if (Arg.getKind() != TemplateArgument::ArgKind::Expression || + Arg.isInstantiationDependent()) { + Arg.print(Policy, ParmListOstream, /* IncludeType = */ false); + return; + } + + Expr *E = Arg.getAsExpr(); + assert(E && "Failed to get an Expr for an Expression template arg?"); + if (E->getType().getTypePtr()->isScopedEnumeralType()) { + // Scoped enumerations can't be implicitly cast from integers, so + // we don't need to evaluate them. + Arg.print(Policy, ParmListOstream, /* IncludeType = */ false); + return; + } + + Expr::EvalResult Res; + [[maybe_unused]] bool Success = + Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context); + assert(Success && "invalid non-type template argument?"); + assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?"); + Res.Val.printPretty(ParmListOstream, Policy, Arg.getAsExpr()->getType(), + &Context); + }; for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()), SE = SpecArgs.size(); I < E; ++I) { if (I != 0) ParmListOstream << ", "; - if (I < SE) // A specialized argument exists, use it - SpecArgs[I].print(Policy, ParmListOstream, false /* IncludeType */); - else // Print a canonical form of a default argument - DeclArgs[I].print(Policy, ParmListOstream, false /* IncludeType */); + // If we have a specialized argument, use it. Otherwise fallback to a + // default argument. + TemplateArgPrinter(I < SE ? SpecArgs[I] : DeclArgs[I]); } ParmListOstream << ">"; @@ -7236,7 +7262,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) { // template arguments that match default template arguments while printing // template-ids, even if the source code doesn't reference them. Policy.EnforceDefaultTemplateArgs = true; - FreeFunctionPrinter FFPrinter(O, Policy); + FreeFunctionPrinter FFPrinter(O, Policy, S.getASTContext()); if (FTD) { FFPrinter.printFreeFunctionDeclaration(FTD); if (const auto kind = K.SyclKernel->getTemplateSpecializationKind(); diff --git a/clang/test/CodeGenSYCL/free-function-kernel-expr-as-template-arg.cpp b/clang/test/CodeGenSYCL/free-function-kernel-expr-as-template-arg.cpp new file mode 100644 index 0000000000000..81120ef702d52 --- /dev/null +++ b/clang/test/CodeGenSYCL/free-function-kernel-expr-as-template-arg.cpp @@ -0,0 +1,110 @@ +// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown -sycl-std=2020 -fsycl-int-header=%t.h %s +// RUN: FileCheck -input-file=%t.h %s +// +// The purpose of this test is to ensure that forward declarations of free +// function kernels are emitted properly. +// However, this test checks a specific scenario: +// - free function argument is a template which accepts constant expressions as +// arguments + +constexpr int A = 2; +constexpr int B = 3; + +namespace ns { + +constexpr int C = 4; + +struct Foo { + static constexpr int D = 5; +}; + +enum non_class_enum { + VAL_A, + VAL_B +}; + +enum class class_enum { + VAL_A, + VAL_B +}; + +enum non_class_enum_typed : int { + VAL_C, + VAL_D +}; + +enum class class_enum_typed : int { + VAL_C, + VAL_D +}; + +constexpr int bar(int arg) { + return arg + 42; +} + +} // namespace ns + +template +struct Arg {}; + +template +struct Arg2 {}; + +template +struct Arg3 {}; + +template +struct Arg4 {}; + +template +struct Arg5 {}; + +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void constant(Arg<1>) {} + +// CHECK: void constant(Arg<1> ); + +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void constexpr_v(Arg) {} + +// CHECK: void constexpr_v(Arg<2> ); + +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void constexpr_expr(Arg) {} + +// CHECK: void constexpr_expr(Arg<6> ); + +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void constexpr_ns(Arg) {} + +// CHECK: void constexpr_ns(Arg<4> ); + +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void constexpr_ns2(Arg) {} + +// CHECK: void constexpr_ns2(Arg<5> ); + +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void constexpr_ns2(Arg2) {} + +// CHECK: void constexpr_ns2(Arg2 ); + +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void constexpr_ns2(Arg3) {} + +// CHECK: void constexpr_ns2(Arg3 ); + +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void constexpr_ns2(Arg4) {} + +// CHECK: void constexpr_ns2(Arg4 ); + +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void constexpr_ns2(Arg5) {} + +// CHECK: void constexpr_ns2(Arg5 ); + +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void constexpr_call(Arg) {} + +// CHECK: void constexpr_call(Arg<45> );