Skip to content

Commit 6db442a

Browse files
[SYCL] Allow free function kernel args be templated on integer expressions (#20187)
`constexpr` variables are not forward-declarable so if one is used as a template parameter of a free function kernel argument, we cannot reference the variable, but must inline the value into the integration header.
1 parent 7675502 commit 6db442a

File tree

2 files changed

+145
-9
lines changed

2 files changed

+145
-9
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6626,10 +6626,12 @@ class FreeFunctionPrinter {
66266626
raw_ostream &O;
66276627
PrintingPolicy &Policy;
66286628
bool NSInserted = false;
6629+
ASTContext &Context;
66296630

66306631
public:
6631-
FreeFunctionPrinter(raw_ostream &O, PrintingPolicy &PrintPolicy)
6632-
: O(O), Policy(PrintPolicy) {}
6632+
FreeFunctionPrinter(raw_ostream &O, PrintingPolicy &PrintPolicy,
6633+
ASTContext &Context)
6634+
: O(O), Policy(PrintPolicy), Context(Context) {}
66336635

66346636
/// Emits the function declaration of template free function.
66356637
/// \param FTD The function declaration to print.
@@ -6826,18 +6828,42 @@ class FreeFunctionPrinter {
68266828
CTN.getAsTemplateDecl()->printQualifiedName(ParmListOstream);
68276829
ParmListOstream << "<";
68286830

6829-
auto SpecArgs = TST->template_arguments();
6830-
auto DeclArgs = CTST->template_arguments();
6831+
ArrayRef<TemplateArgument> SpecArgs = TST->template_arguments();
6832+
ArrayRef<TemplateArgument> DeclArgs = CTST->template_arguments();
6833+
6834+
auto TemplateArgPrinter = [&](const TemplateArgument &Arg) {
6835+
if (Arg.getKind() != TemplateArgument::ArgKind::Expression ||
6836+
Arg.isInstantiationDependent()) {
6837+
Arg.print(Policy, ParmListOstream, /* IncludeType = */ false);
6838+
return;
6839+
}
6840+
6841+
Expr *E = Arg.getAsExpr();
6842+
assert(E && "Failed to get an Expr for an Expression template arg?");
6843+
if (E->getType().getTypePtr()->isScopedEnumeralType()) {
6844+
// Scoped enumerations can't be implicitly cast from integers, so
6845+
// we don't need to evaluate them.
6846+
Arg.print(Policy, ParmListOstream, /* IncludeType = */ false);
6847+
return;
6848+
}
6849+
6850+
Expr::EvalResult Res;
6851+
[[maybe_unused]] bool Success =
6852+
Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context);
6853+
assert(Success && "invalid non-type template argument?");
6854+
assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?");
6855+
Res.Val.printPretty(ParmListOstream, Policy, Arg.getAsExpr()->getType(),
6856+
&Context);
6857+
};
68316858

68326859
for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()),
68336860
SE = SpecArgs.size();
68346861
I < E; ++I) {
68356862
if (I != 0)
68366863
ParmListOstream << ", ";
6837-
if (I < SE) // A specialized argument exists, use it
6838-
SpecArgs[I].print(Policy, ParmListOstream, false /* IncludeType */);
6839-
else // Print a canonical form of a default argument
6840-
DeclArgs[I].print(Policy, ParmListOstream, false /* IncludeType */);
6864+
// If we have a specialized argument, use it. Otherwise fallback to a
6865+
// default argument.
6866+
TemplateArgPrinter(I < SE ? SpecArgs[I] : DeclArgs[I]);
68416867
}
68426868

68436869
ParmListOstream << ">";
@@ -7236,7 +7262,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
72367262
// template arguments that match default template arguments while printing
72377263
// template-ids, even if the source code doesn't reference them.
72387264
Policy.EnforceDefaultTemplateArgs = true;
7239-
FreeFunctionPrinter FFPrinter(O, Policy);
7265+
FreeFunctionPrinter FFPrinter(O, Policy, S.getASTContext());
72407266
if (FTD) {
72417267
FFPrinter.printFreeFunctionDeclaration(FTD);
72427268
if (const auto kind = K.SyclKernel->getTemplateSpecializationKind();
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown -sycl-std=2020 -fsycl-int-header=%t.h %s
2+
// RUN: FileCheck -input-file=%t.h %s
3+
//
4+
// The purpose of this test is to ensure that forward declarations of free
5+
// function kernels are emitted properly.
6+
// However, this test checks a specific scenario:
7+
// - free function argument is a template which accepts constant expressions as
8+
// arguments
9+
10+
constexpr int A = 2;
11+
constexpr int B = 3;
12+
13+
namespace ns {
14+
15+
constexpr int C = 4;
16+
17+
struct Foo {
18+
static constexpr int D = 5;
19+
};
20+
21+
enum non_class_enum {
22+
VAL_A,
23+
VAL_B
24+
};
25+
26+
enum class class_enum {
27+
VAL_A,
28+
VAL_B
29+
};
30+
31+
enum non_class_enum_typed : int {
32+
VAL_C,
33+
VAL_D
34+
};
35+
36+
enum class class_enum_typed : int {
37+
VAL_C,
38+
VAL_D
39+
};
40+
41+
constexpr int bar(int arg) {
42+
return arg + 42;
43+
}
44+
45+
} // namespace ns
46+
47+
template<int V>
48+
struct Arg {};
49+
50+
template<ns::non_class_enum V>
51+
struct Arg2 {};
52+
53+
template<ns::non_class_enum_typed V>
54+
struct Arg3 {};
55+
56+
template<ns::class_enum V>
57+
struct Arg4 {};
58+
59+
template<ns::class_enum_typed V>
60+
struct Arg5 {};
61+
62+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
63+
void constant(Arg<1>) {}
64+
65+
// CHECK: void constant(Arg<1> );
66+
67+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
68+
void constexpr_v(Arg<A>) {}
69+
70+
// CHECK: void constexpr_v(Arg<2> );
71+
72+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
73+
void constexpr_expr(Arg<A * B>) {}
74+
75+
// CHECK: void constexpr_expr(Arg<6> );
76+
77+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
78+
void constexpr_ns(Arg<ns::C>) {}
79+
80+
// CHECK: void constexpr_ns(Arg<4> );
81+
82+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
83+
void constexpr_ns2(Arg<ns::Foo::D>) {}
84+
85+
// CHECK: void constexpr_ns2(Arg<5> );
86+
87+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
88+
void constexpr_ns2(Arg2<ns::non_class_enum::VAL_A>) {}
89+
90+
// CHECK: void constexpr_ns2(Arg2<ns::VAL_A> );
91+
92+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
93+
void constexpr_ns2(Arg3<ns::non_class_enum_typed::VAL_C>) {}
94+
95+
// CHECK: void constexpr_ns2(Arg3<ns::VAL_C> );
96+
97+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
98+
void constexpr_ns2(Arg4<ns::class_enum::VAL_A>) {}
99+
100+
// CHECK: void constexpr_ns2(Arg4<ns::class_enum::VAL_A> );
101+
102+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
103+
void constexpr_ns2(Arg5<ns::class_enum_typed::VAL_C>) {}
104+
105+
// CHECK: void constexpr_ns2(Arg5<ns::class_enum_typed::VAL_C> );
106+
107+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
108+
void constexpr_call(Arg<ns::bar(B)>) {}
109+
110+
// CHECK: void constexpr_call(Arg<45> );

0 commit comments

Comments
 (0)