From 91dd52198cbc1d13e8419543603848a766917231 Mon Sep 17 00:00:00 2001 From: Alexey Sachkov Date: Wed, 24 Sep 2025 15:28:21 +0200 Subject: [PATCH 1/3] [SYCL] Allow free function kernel args to be templated on integer expressions `constexpr` variables are not forward-declarable so if one is used as a template parameter of a free function kernel argument, we cannot reference the variable, but must inline the value into the integration header. --- clang/lib/Sema/SemaSYCL.cpp | 43 ++++++-- ...e-function-kernel-expr-as-template-arg.cpp | 101 ++++++++++++++++++ 2 files changed, 135 insertions(+), 9 deletions(-) create mode 100644 clang/test/CodeGenSYCL/free-function-kernel-expr-as-template-arg.cpp diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index b9a5d202af373..a9d3aa1011d85 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -6626,10 +6626,12 @@ class FreeFunctionPrinter { raw_ostream &O; PrintingPolicy &Policy; bool NSInserted = false; + ASTContext &Context; public: - FreeFunctionPrinter(raw_ostream &O, PrintingPolicy &PrintPolicy) - : O(O), Policy(PrintPolicy) {} + FreeFunctionPrinter(raw_ostream &O, PrintingPolicy &PrintPolicy, + ASTContext &Context) + : O(O), Policy(PrintPolicy), Context(Context) {} /// Emits the function declaration of template free function. /// \param FTD The function declaration to print. 
@@ -6826,18 +6828,41 @@ class FreeFunctionPrinter { CTN.getAsTemplateDecl()->printQualifiedName(ParmListOstream); ParmListOstream << "<"; - auto SpecArgs = TST->template_arguments(); - auto DeclArgs = CTST->template_arguments(); + ArrayRef SpecArgs = TST->template_arguments(); + ArrayRef DeclArgs = CTST->template_arguments(); + + auto TemplateArgPrinter = [&](const TemplateArgument &Arg) { + if (Arg.getKind() != TemplateArgument::ArgKind::Expression || + Arg.isInstantiationDependent()) { + Arg.print(Policy, ParmListOstream, false /* IncludeType */); + return; + } + + Expr *E = Arg.getAsExpr(); + if (E->getType().getTypePtr()->isScopedEnumeralType()) { + // Scoped enumerations can't be implicitly cast from integers, so + // we don't need to evaluate them + Arg.print(Policy, ParmListOstream, false /* IncludeType */); + return; + } + + Expr::EvalResult Res; + [[maybe_unused]] bool Success = + Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context); + assert(Success && "invalid non-type template argument?"); + assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?"); + Res.Val.printPretty(ParmListOstream, Policy, + Arg.getAsExpr()->getType(), &Context); + }; for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()), SE = SpecArgs.size(); I < E; ++I) { if (I != 0) ParmListOstream << ", "; - if (I < SE) // A specialized argument exists, use it - SpecArgs[I].print(Policy, ParmListOstream, false /* IncludeType */); - else // Print a canonical form of a default argument - DeclArgs[I].print(Policy, ParmListOstream, false /* IncludeType */); + // If we have a specialized argument, use it. Otherwise fallback to a + // default argument + TemplateArgPrinter(I < SE ? SpecArgs[I] : DeclArgs[I]); } ParmListOstream << ">"; @@ -7236,7 +7261,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) { // template arguments that match default template arguments while printing // template-ids, even if the source code doesn't reference them. 
Policy.EnforceDefaultTemplateArgs = true; - FreeFunctionPrinter FFPrinter(O, Policy); + FreeFunctionPrinter FFPrinter(O, Policy, S.getASTContext()); if (FTD) { FFPrinter.printFreeFunctionDeclaration(FTD); if (const auto kind = K.SyclKernel->getTemplateSpecializationKind(); diff --git a/clang/test/CodeGenSYCL/free-function-kernel-expr-as-template-arg.cpp b/clang/test/CodeGenSYCL/free-function-kernel-expr-as-template-arg.cpp new file mode 100644 index 0000000000000..69820a6e65b20 --- /dev/null +++ b/clang/test/CodeGenSYCL/free-function-kernel-expr-as-template-arg.cpp @@ -0,0 +1,101 @@ +// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown -sycl-std=2020 -fsycl-int-header=%t.h %s +// RUN: FileCheck -input-file=%t.h %s +// +// The purpose of this test is to ensure that forward declarations of free +// function kernels are emitted properly. +// However, this test checks a specific scenario: +// - free function argument is a template which accepts constant expressions as +// arguments + +constexpr int A = 2; +constexpr int B = 3; + +namespace ns { + +constexpr int C = 4; + +struct Foo { + static constexpr int D = 5; +}; + +enum non_class_enum { + VAL_A, + VAL_B +}; + +enum class class_enum { + VAL_A, + VAL_B +}; + +enum non_class_enum_typed : int { + VAL_C, + VAL_D +}; + +enum class class_enum_typed : int { + VAL_C, + VAL_D +}; + +} // namespace ns + +template +struct Arg {}; + +template +struct Arg2 {}; + +template +struct Arg3 {}; + +template +struct Arg4 {}; + +template +struct Arg5 {}; + +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void constant(Arg<1>) {} + +// CHECK: void constant(Arg<1> ); + +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void constexpr_v(Arg) {} + +// CHECK: void constexpr_v(Arg<2> ); + +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void constexpr_expr(Arg) {} + +// CHECK: void constexpr_expr(Arg<6> ); + 
+[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void constexpr_ns(Arg) {} + +// CHECK: void constexpr_ns(Arg<4> ); + +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void constexpr_ns2(Arg) {} + +// CHECK: void constexpr_ns2(Arg<5> ); + +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void constexpr_ns2(Arg2) {} + +// CHECK: void constexpr_ns2(Arg2 ); + +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void constexpr_ns2(Arg3) {} + +// CHECK: void constexpr_ns2(Arg3 ); + +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void constexpr_ns2(Arg4) {} + +// CHECK: void constexpr_ns2(Arg4 ); + +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void constexpr_ns2(Arg5) {} + +// CHECK: void constexpr_ns2(Arg5 ); From 39b03937365b8d002e8531c7b6182bbaef375741 Mon Sep 17 00:00:00 2001 From: Alexey Sachkov Date: Wed, 24 Sep 2025 16:07:20 +0200 Subject: [PATCH 2/3] Apply clang-format --- clang/lib/Sema/SemaSYCL.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index a9d3aa1011d85..d8d5d3bcccc30 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -6851,8 +6851,8 @@ class FreeFunctionPrinter { Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context); assert(Success && "invalid non-type template argument?"); assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?"); - Res.Val.printPretty(ParmListOstream, Policy, - Arg.getAsExpr()->getType(), &Context); + Res.Val.printPretty(ParmListOstream, Policy, Arg.getAsExpr()->getType(), + &Context); }; for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()), From e502a1b2adffe9cc92b0f512d33da299a92601f0 Mon Sep 17 00:00:00 2001 From: Alexey Sachkov Date: Wed, 24 Sep 2025 17:37:51 +0200 Subject: [PATCH 3/3] Apply comments --- clang/lib/Sema/SemaSYCL.cpp | 
9 +++++---- .../free-function-kernel-expr-as-template-arg.cpp | 9 +++++++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index d8d5d3bcccc30..ffe1280c387f1 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -6834,15 +6834,16 @@ class FreeFunctionPrinter { auto TemplateArgPrinter = [&](const TemplateArgument &Arg) { if (Arg.getKind() != TemplateArgument::ArgKind::Expression || Arg.isInstantiationDependent()) { - Arg.print(Policy, ParmListOstream, false /* IncludeType */); + Arg.print(Policy, ParmListOstream, /* IncludeType = */ false); return; } Expr *E = Arg.getAsExpr(); + assert(E && "Failed to get an Expr for an Expression template arg?"); if (E->getType().getTypePtr()->isScopedEnumeralType()) { // Scoped enumerations can't be implicitly cast from integers, so - // we don't need to evaluate them - Arg.print(Policy, ParmListOstream, false /* IncludeType */); + // we don't need to evaluate them. + Arg.print(Policy, ParmListOstream, /* IncludeType = */ false); return; } @@ -6861,7 +6862,7 @@ class FreeFunctionPrinter { if (I != 0) ParmListOstream << ", "; // If we have a specialized argument, use it. Otherwise fallback to a - // default argument + // default argument. TemplateArgPrinter(I < SE ? 
SpecArgs[I] : DeclArgs[I]); } diff --git a/clang/test/CodeGenSYCL/free-function-kernel-expr-as-template-arg.cpp b/clang/test/CodeGenSYCL/free-function-kernel-expr-as-template-arg.cpp index 69820a6e65b20..81120ef702d52 100644 --- a/clang/test/CodeGenSYCL/free-function-kernel-expr-as-template-arg.cpp +++ b/clang/test/CodeGenSYCL/free-function-kernel-expr-as-template-arg.cpp @@ -38,6 +38,10 @@ enum class class_enum_typed : int { VAL_D }; +constexpr int bar(int arg) { + return arg + 42; +} + } // namespace ns template @@ -99,3 +103,8 @@ void constexpr_ns2(Arg4) {} void constexpr_ns2(Arg5) {} // CHECK: void constexpr_ns2(Arg5 ); + +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void constexpr_call(Arg) {} + +// CHECK: void constexpr_call(Arg<45> );