Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 38 additions & 12 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6597,10 +6597,12 @@ class FreeFunctionPrinter {
raw_ostream &O;
PrintingPolicy &Policy;
bool NSInserted = false;
ASTContext &Context;

public:
FreeFunctionPrinter(raw_ostream &O, PrintingPolicy &PrintPolicy)
: O(O), Policy(PrintPolicy) {}
FreeFunctionPrinter(raw_ostream &O, PrintingPolicy &PrintPolicy,
ASTContext &Context)
: O(O), Policy(PrintPolicy), Context(Context) {}

/// Emits the function declaration of template free function.
/// \param FTD The function declaration to print.
Expand Down Expand Up @@ -6793,22 +6795,46 @@ class FreeFunctionPrinter {
continue;
}

TemplateName TN = TST->getTemplateName();
auto SpecArgs = TST->template_arguments();
auto DeclArgs = CTST->template_arguments();

TN.getAsTemplateDecl()->printQualifiedName(ParmListOstream);
TemplateName CTN = CTST->getTemplateName();
CTN.getAsTemplateDecl()->printQualifiedName(ParmListOstream);
ParmListOstream << "<";

ArrayRef<TemplateArgument> SpecArgs = TST->template_arguments();
ArrayRef<TemplateArgument> DeclArgs = CTST->template_arguments();

auto TemplateArgPrinter = [&](const TemplateArgument &Arg) {
if (Arg.getKind() != TemplateArgument::ArgKind::Expression ||
Arg.isInstantiationDependent()) {
Arg.print(Policy, ParmListOstream, /* IncludeType = */ false);
return;
}

Expr *E = Arg.getAsExpr();
assert(E && "Failed to get an Expr for an Expression template arg?");
if (E->getType().getTypePtr()->isScopedEnumeralType()) {
// Scoped enumerations can't be implicitly cast from integers, so
// we don't need to evaluate them.
Arg.print(Policy, ParmListOstream, /* IncludeType = */ false);
return;
}

Expr::EvalResult Res;
[[maybe_unused]] bool Success =
Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context);
assert(Success && "invalid non-type template argument?");
assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?");
Res.Val.printPretty(ParmListOstream, Policy, Arg.getAsExpr()->getType(),
&Context);
};

for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()),
SE = SpecArgs.size();
I < E; ++I) {
if (I != 0)
ParmListOstream << ", ";
if (I < SE) // A specialized argument exists, use it
SpecArgs[I].print(Policy, ParmListOstream, false /* IncludeType */);
else // Print a canonical form of a default argument
DeclArgs[I].print(Policy, ParmListOstream, false /* IncludeType */);
// If we have a specialized argument, use it. Otherwise fallback to a
// default argument.
TemplateArgPrinter(I < SE ? SpecArgs[I] : DeclArgs[I]);
}

ParmListOstream << ">";
Expand Down Expand Up @@ -7207,7 +7233,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
// template arguments that match default template arguments while printing
// template-ids, even if the source code doesn't reference them.
Policy.EnforceDefaultTemplateArgs = true;
FreeFunctionPrinter FFPrinter(O, Policy);
FreeFunctionPrinter FFPrinter(O, Policy, S.getASTContext());
if (FTD) {
FFPrinter.printFreeFunctionDeclaration(FTD);
if (const auto kind = K.SyclKernel->getTemplateSpecializationKind();
Expand Down
110 changes: 110 additions & 0 deletions clang/test/CodeGenSYCL/free-function-kernel-expr-as-template-arg.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown -sycl-std=2020 -fsycl-int-header=%t.h %s
// RUN: FileCheck -input-file=%t.h %s
//
// The purpose of this test is to ensure that forward declarations of free
// function kernels are emitted properly.
// However, this test checks a specific scenario:
// - free function argument is a template which accepts constant expressions as
// arguments

constexpr int A = 2;
constexpr int B = 3;

namespace ns {

constexpr int C = 4;

struct Foo {
static constexpr int D = 5;
};

enum non_class_enum {
VAL_A,
VAL_B
};

enum class class_enum {
VAL_A,
VAL_B
};

enum non_class_enum_typed : int {
VAL_C,
VAL_D
};

enum class class_enum_typed : int {
VAL_C,
VAL_D
};

constexpr int bar(int arg) {
return arg + 42;
}

} // namespace ns

template<int V>
struct Arg {};

template<ns::non_class_enum V>
struct Arg2 {};

template<ns::non_class_enum_typed V>
struct Arg3 {};

template<ns::class_enum V>
struct Arg4 {};

template<ns::class_enum_typed V>
struct Arg5 {};

[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constant(Arg<1>) {}

// CHECK: void constant(Arg<1> );

[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_v(Arg<A>) {}

// CHECK: void constexpr_v(Arg<2> );

[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_expr(Arg<A * B>) {}

// CHECK: void constexpr_expr(Arg<6> );

[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_ns(Arg<ns::C>) {}

// CHECK: void constexpr_ns(Arg<4> );

[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_ns2(Arg<ns::Foo::D>) {}

// CHECK: void constexpr_ns2(Arg<5> );

[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_ns2(Arg2<ns::non_class_enum::VAL_A>) {}

// CHECK: void constexpr_ns2(Arg2<ns::VAL_A> );

[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_ns2(Arg3<ns::non_class_enum_typed::VAL_C>) {}

// CHECK: void constexpr_ns2(Arg3<ns::VAL_C> );

[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_ns2(Arg4<ns::class_enum::VAL_A>) {}

// CHECK: void constexpr_ns2(Arg4<ns::class_enum::VAL_A> );

[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_ns2(Arg5<ns::class_enum_typed::VAL_C>) {}

// CHECK: void constexpr_ns2(Arg5<ns::class_enum_typed::VAL_C> );

[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_call(Arg<ns::bar(B)>) {}

// CHECK: void constexpr_call(Arg<45> );
63 changes: 63 additions & 0 deletions clang/test/CodeGenSYCL/free-function-kernel-type-alias-arg.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown -sycl-std=2020 -fsycl-int-header=%t.h %s
// RUN: FileCheck -input-file=%t.h %s
//
// The purpose of this test is to ensure that forward declarations of free
// function kernels are emitted properly.
// However, this test checks a specific scenario:
// - free function arguments are type aliases (through using or typedef)

namespace ns {

using IntUsing = int;
typedef int IntTypedef;

template <typename T>
struct Foo {};

using FooIntUsing = Foo<int>;
typedef Foo<int> FooIntTypedef;

template <typename T1, typename T2>
struct Bar {};

template<typename T1>
using BarUsing = Bar<T1, float>;

class Baz {
public:
using type = BarUsing<double>;
};

} // namespace ns

[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void int_using(ns::IntUsing Arg) {}

// CHECK: void int_using(int Arg);

[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void int_typedef(ns::IntTypedef Arg) {}

// CHECK: void int_typedef(int Arg);

[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void foo_using(ns::FooIntUsing Arg) {}

// CHECK: void foo_using(ns::Foo<int> Arg);

[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void foo_typedef(ns::FooIntTypedef Arg) {}

// CHECK: void foo_typedef(ns::Foo<int> Arg);

template<typename T>
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void bar_using(ns::BarUsing<T> Arg) {}
template void bar_using(ns::BarUsing<int>);

// CHECK: template <typename T> void bar_using(ns::Bar<T, float>);

[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void baz_type(ns::Baz::type Arg) {}

// CHECK: void baz_type(ns::Bar<double, float> Arg);