Skip to content

Commit 93e7577

Browse files
[SYCL][6.3] Cherry-pick free function kernel fixes (#20236)
This is joined cherry-pick of #20187 and #20123 --- [SYCL] Allow free function kernel args be templated on integer expressions (#20187) `constexpr` variables are not forward-declarable so if one is used as a template parameter of a free function kernel argument, we cannot reference the variable, but must inline the value into the integration header. --- [SYCL] Fix error with type aliases used as free function kernel args (#20123) This PR fixes type name that is being printed as free function kernel argument type in its forward-declaration in the integration header. Before the change, we used the original argument type name, which could be an alias - this patch makes use of the canonical type's name to make sure that all type aliases are "unwrapped".
1 parent 3a6b088 commit 93e7577

File tree

3 files changed

+211
-12
lines changed

3 files changed

+211
-12
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6597,10 +6597,12 @@ class FreeFunctionPrinter {
65976597
raw_ostream &O;
65986598
PrintingPolicy &Policy;
65996599
bool NSInserted = false;
6600+
ASTContext &Context;
66006601

66016602
public:
6602-
FreeFunctionPrinter(raw_ostream &O, PrintingPolicy &PrintPolicy)
6603-
: O(O), Policy(PrintPolicy) {}
6603+
FreeFunctionPrinter(raw_ostream &O, PrintingPolicy &PrintPolicy,
6604+
ASTContext &Context)
6605+
: O(O), Policy(PrintPolicy), Context(Context) {}
66046606

66056607
/// Emits the function declaration of template free function.
66066608
/// \param FTD The function declaration to print.
@@ -6793,22 +6795,46 @@ class FreeFunctionPrinter {
67936795
continue;
67946796
}
67956797

6796-
TemplateName TN = TST->getTemplateName();
6797-
auto SpecArgs = TST->template_arguments();
6798-
auto DeclArgs = CTST->template_arguments();
6799-
6800-
TN.getAsTemplateDecl()->printQualifiedName(ParmListOstream);
6798+
TemplateName CTN = CTST->getTemplateName();
6799+
CTN.getAsTemplateDecl()->printQualifiedName(ParmListOstream);
68016800
ParmListOstream << "<";
68026801

6802+
ArrayRef<TemplateArgument> SpecArgs = TST->template_arguments();
6803+
ArrayRef<TemplateArgument> DeclArgs = CTST->template_arguments();
6804+
6805+
auto TemplateArgPrinter = [&](const TemplateArgument &Arg) {
6806+
if (Arg.getKind() != TemplateArgument::ArgKind::Expression ||
6807+
Arg.isInstantiationDependent()) {
6808+
Arg.print(Policy, ParmListOstream, /* IncludeType = */ false);
6809+
return;
6810+
}
6811+
6812+
Expr *E = Arg.getAsExpr();
6813+
assert(E && "Failed to get an Expr for an Expression template arg?");
6814+
if (E->getType().getTypePtr()->isScopedEnumeralType()) {
6815+
// Scoped enumerations can't be implicitly cast from integers, so
6816+
// we don't need to evaluate them.
6817+
Arg.print(Policy, ParmListOstream, /* IncludeType = */ false);
6818+
return;
6819+
}
6820+
6821+
Expr::EvalResult Res;
6822+
[[maybe_unused]] bool Success =
6823+
Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context);
6824+
assert(Success && "invalid non-type template argument?");
6825+
assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?");
6826+
Res.Val.printPretty(ParmListOstream, Policy, Arg.getAsExpr()->getType(),
6827+
&Context);
6828+
};
6829+
68036830
for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()),
68046831
SE = SpecArgs.size();
68056832
I < E; ++I) {
68066833
if (I != 0)
68076834
ParmListOstream << ", ";
6808-
if (I < SE) // A specialized argument exists, use it
6809-
SpecArgs[I].print(Policy, ParmListOstream, false /* IncludeType */);
6810-
else // Print a canonical form of a default argument
6811-
DeclArgs[I].print(Policy, ParmListOstream, false /* IncludeType */);
6835+
// If we have a specialized argument, use it. Otherwise fallback to a
6836+
// default argument.
6837+
TemplateArgPrinter(I < SE ? SpecArgs[I] : DeclArgs[I]);
68126838
}
68136839

68146840
ParmListOstream << ">";
@@ -7207,7 +7233,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
72077233
// template arguments that match default template arguments while printing
72087234
// template-ids, even if the source code doesn't reference them.
72097235
Policy.EnforceDefaultTemplateArgs = true;
7210-
FreeFunctionPrinter FFPrinter(O, Policy);
7236+
FreeFunctionPrinter FFPrinter(O, Policy, S.getASTContext());
72117237
if (FTD) {
72127238
FFPrinter.printFreeFunctionDeclaration(FTD);
72137239
if (const auto kind = K.SyclKernel->getTemplateSpecializationKind();
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown -sycl-std=2020 -fsycl-int-header=%t.h %s
2+
// RUN: FileCheck -input-file=%t.h %s
3+
//
4+
// The purpose of this test is to ensure that forward declarations of free
5+
// function kernels are emitted properly.
6+
// However, this test checks a specific scenario:
7+
// - free function argument is a template which accepts constant expressions as
8+
// arguments
9+
10+
constexpr int A = 2;
11+
constexpr int B = 3;
12+
13+
namespace ns {
14+
15+
constexpr int C = 4;
16+
17+
struct Foo {
18+
static constexpr int D = 5;
19+
};
20+
21+
enum non_class_enum {
22+
VAL_A,
23+
VAL_B
24+
};
25+
26+
enum class class_enum {
27+
VAL_A,
28+
VAL_B
29+
};
30+
31+
enum non_class_enum_typed : int {
32+
VAL_C,
33+
VAL_D
34+
};
35+
36+
enum class class_enum_typed : int {
37+
VAL_C,
38+
VAL_D
39+
};
40+
41+
constexpr int bar(int arg) {
42+
return arg + 42;
43+
}
44+
45+
} // namespace ns
46+
47+
template<int V>
48+
struct Arg {};
49+
50+
template<ns::non_class_enum V>
51+
struct Arg2 {};
52+
53+
template<ns::non_class_enum_typed V>
54+
struct Arg3 {};
55+
56+
template<ns::class_enum V>
57+
struct Arg4 {};
58+
59+
template<ns::class_enum_typed V>
60+
struct Arg5 {};
61+
62+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
63+
void constant(Arg<1>) {}
64+
65+
// CHECK: void constant(Arg<1> );
66+
67+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
68+
void constexpr_v(Arg<A>) {}
69+
70+
// CHECK: void constexpr_v(Arg<2> );
71+
72+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
73+
void constexpr_expr(Arg<A * B>) {}
74+
75+
// CHECK: void constexpr_expr(Arg<6> );
76+
77+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
78+
void constexpr_ns(Arg<ns::C>) {}
79+
80+
// CHECK: void constexpr_ns(Arg<4> );
81+
82+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
83+
void constexpr_ns2(Arg<ns::Foo::D>) {}
84+
85+
// CHECK: void constexpr_ns2(Arg<5> );
86+
87+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
88+
void constexpr_ns2(Arg2<ns::non_class_enum::VAL_A>) {}
89+
90+
// CHECK: void constexpr_ns2(Arg2<ns::VAL_A> );
91+
92+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
93+
void constexpr_ns2(Arg3<ns::non_class_enum_typed::VAL_C>) {}
94+
95+
// CHECK: void constexpr_ns2(Arg3<ns::VAL_C> );
96+
97+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
98+
void constexpr_ns2(Arg4<ns::class_enum::VAL_A>) {}
99+
100+
// CHECK: void constexpr_ns2(Arg4<ns::class_enum::VAL_A> );
101+
102+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
103+
void constexpr_ns2(Arg5<ns::class_enum_typed::VAL_C>) {}
104+
105+
// CHECK: void constexpr_ns2(Arg5<ns::class_enum_typed::VAL_C> );
106+
107+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
108+
void constexpr_call(Arg<ns::bar(B)>) {}
109+
110+
// CHECK: void constexpr_call(Arg<45> );
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown -sycl-std=2020 -fsycl-int-header=%t.h %s
2+
// RUN: FileCheck -input-file=%t.h %s
3+
//
4+
// The purpose of this test is to ensure that forward declarations of free
5+
// function kernels are emitted properly.
6+
// However, this test checks a specific scenario:
7+
// - free function arguments are type aliases (through using or typedef)
8+
9+
namespace ns {
10+
11+
using IntUsing = int;
12+
typedef int IntTypedef;
13+
14+
template <typename T>
15+
struct Foo {};
16+
17+
using FooIntUsing = Foo<int>;
18+
typedef Foo<int> FooIntTypedef;
19+
20+
template <typename T1, typename T2>
21+
struct Bar {};
22+
23+
template<typename T1>
24+
using BarUsing = Bar<T1, float>;
25+
26+
class Baz {
27+
public:
28+
using type = BarUsing<double>;
29+
};
30+
31+
} // namespace ns
32+
33+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
34+
void int_using(ns::IntUsing Arg) {}
35+
36+
// CHECK: void int_using(int Arg);
37+
38+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
39+
void int_typedef(ns::IntTypedef Arg) {}
40+
41+
// CHECK: void int_typedef(int Arg);
42+
43+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
44+
void foo_using(ns::FooIntUsing Arg) {}
45+
46+
// CHECK: void foo_using(ns::Foo<int> Arg);
47+
48+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
49+
void foo_typedef(ns::FooIntTypedef Arg) {}
50+
51+
// CHECK: void foo_typedef(ns::Foo<int> Arg);
52+
53+
template<typename T>
54+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
55+
void bar_using(ns::BarUsing<T> Arg) {}
56+
template void bar_using(ns::BarUsing<int>);
57+
58+
// CHECK: template <typename T> void bar_using(ns::Bar<T, float>);
59+
60+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
61+
void baz_type(ns::Baz::type Arg) {}
62+
63+
// CHECK: void baz_type(ns::Bar<double, float> Arg);

0 commit comments

Comments
 (0)