Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 35 additions & 9 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6626,10 +6626,12 @@ class FreeFunctionPrinter {
raw_ostream &O;
PrintingPolicy &Policy;
bool NSInserted = false;
ASTContext &Context;

public:
FreeFunctionPrinter(raw_ostream &O, PrintingPolicy &PrintPolicy)
: O(O), Policy(PrintPolicy) {}
FreeFunctionPrinter(raw_ostream &O, PrintingPolicy &PrintPolicy,
ASTContext &Context)
: O(O), Policy(PrintPolicy), Context(Context) {}

/// Emits the function declaration of template free function.
/// \param FTD The function declaration to print.
Expand Down Expand Up @@ -6826,18 +6828,42 @@ class FreeFunctionPrinter {
CTN.getAsTemplateDecl()->printQualifiedName(ParmListOstream);
ParmListOstream << "<";

auto SpecArgs = TST->template_arguments();
auto DeclArgs = CTST->template_arguments();
ArrayRef<TemplateArgument> SpecArgs = TST->template_arguments();
ArrayRef<TemplateArgument> DeclArgs = CTST->template_arguments();

auto TemplateArgPrinter = [&](const TemplateArgument &Arg) {
if (Arg.getKind() != TemplateArgument::ArgKind::Expression ||
Arg.isInstantiationDependent()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can Arg.isInstantiationDependent() ever be true in SemaSYCL?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it can. This extra check was prompted by

template <int ArrSize>                                                       
[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel", 0)]]
void ff_7(KArgWithPtrArray<ArrSize> KArg) {                                  
  for (int j = 0; j < ArrSize; j++)                                          
    for (int i = KArg.start[j]; i <= KArg.end[j]; i++)                       
      KArg.data[j][i] = KArg.start[j] + KArg.end[j];                         
}                                                                            

(snippet from clang/test/CodeGenSYCL/free_function_int_header.cpp).

The ArrSize template argument is instantiation-dependent here.

Arg.print(Policy, ParmListOstream, /* IncludeType = */ false);
return;
}

Expr *E = Arg.getAsExpr();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can E end up being null?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It shouldn't be because I check the argument kind above, but I added an assertion just in case (and to help static analyzers)

assert(E && "Failed to get an Expr for an Expression template arg?");
if (E->getType().getTypePtr()->isScopedEnumeralType()) {
// Scoped enumerations can't be implicitly cast from integers, so
// we don't need to evaluate them.
Arg.print(Policy, ParmListOstream, /* IncludeType = */ false);
return;
}

Expr::EvalResult Res;
[[maybe_unused]] bool Success =
Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context);
assert(Success && "invalid non-type template argument?");
assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?");
Res.Val.printPretty(ParmListOstream, Policy, Arg.getAsExpr()->getType(),
&Context);
};

for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()),
SE = SpecArgs.size();
I < E; ++I) {
if (I != 0)
ParmListOstream << ", ";
if (I < SE) // A specialized argument exists, use it
SpecArgs[I].print(Policy, ParmListOstream, false /* IncludeType */);
else // Print a canonical form of a default argument
DeclArgs[I].print(Policy, ParmListOstream, false /* IncludeType */);
// If we have a specialized argument, use it. Otherwise fallback to a
// default argument.
TemplateArgPrinter(I < SE ? SpecArgs[I] : DeclArgs[I]);
}

ParmListOstream << ">";
Expand Down Expand Up @@ -7236,7 +7262,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
// template arguments that match default template arguments while printing
// template-ids, even if the source code doesn't reference them.
Policy.EnforceDefaultTemplateArgs = true;
FreeFunctionPrinter FFPrinter(O, Policy);
FreeFunctionPrinter FFPrinter(O, Policy, S.getASTContext());
if (FTD) {
FFPrinter.printFreeFunctionDeclaration(FTD);
if (const auto kind = K.SyclKernel->getTemplateSpecializationKind();
Expand Down
110 changes: 110 additions & 0 deletions clang/test/CodeGenSYCL/free-function-kernel-expr-as-template-arg.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown -sycl-std=2020 -fsycl-int-header=%t.h %s
// RUN: FileCheck -input-file=%t.h %s
//
// The purpose of this test is to ensure that forward declarations of free
// function kernels are emitted properly.
// However, this test checks a specific scenario:
// - free function argument is a template which accepts constant expressions as
// arguments

constexpr int A = 2;
constexpr int B = 3;

// Namespace-scoped entities that the free function kernels declared later in
// this file use as non-type template arguments.
namespace ns {

// Namespace-scope constexpr variable; used below as Arg<ns::C>.
constexpr int C = 4;

struct Foo {
// Static constexpr data member; used below as Arg<ns::Foo::D>.
static constexpr int D = 5;
};

// Unscoped enumeration without an explicit underlying type.
enum non_class_enum {
VAL_A,
VAL_B
};

// Scoped enumeration without an explicit underlying type.
enum class class_enum {
VAL_A,
VAL_B
};

// Unscoped enumeration with an explicit underlying type (int).
enum non_class_enum_typed : int {
VAL_C,
VAL_D
};

// Scoped enumeration with an explicit underlying type (int).
enum class class_enum_typed : int {
VAL_C,
VAL_D
};

// constexpr function whose call is used as a template argument below
// (Arg<ns::bar(B)>); with B == 3 the call evaluates to 45.
constexpr int bar(int arg) {
return arg + 42;
}

} // namespace ns

// Class templates with a single non-type template parameter each. They are
// used as parameter types of the free function kernels below so that the
// template-id printed into the integration header can be verified.

// Non-type parameter of type int.
template<int V>
struct Arg {};

// Non-type parameter of an unscoped enumeration type.
template<ns::non_class_enum V>
struct Arg2 {};

// Non-type parameter of an unscoped enumeration type with a fixed underlying
// type.
template<ns::non_class_enum_typed V>
struct Arg3 {};

// Non-type parameter of a scoped enumeration type.
template<ns::class_enum V>
struct Arg4 {};

// Non-type parameter of a scoped enumeration type with a fixed underlying
// type.
template<ns::class_enum_typed V>
struct Arg5 {};

// Template argument is an integer literal; it is expected to be emitted
// unchanged.
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constant(Arg<1>) {}

// CHECK: void constant(Arg<1> );

// Template argument is a global constexpr variable; the expected declaration
// contains its evaluated value (2) instead of the name A.
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_v(Arg<A>) {}

// CHECK: void constexpr_v(Arg<2> );
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note for reviewers: we used to print Arg<A> here, where A isn't forward-declared in the integration header (because we can't forward-declare constexpr variables), which caused compilation errors.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about putting constexpr function calls there?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added such a test in e502a1b.


// Template argument is an arithmetic expression over constexpr variables;
// the expected declaration contains the evaluated result (2 * 3 == 6).
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_expr(Arg<A * B>) {}

// CHECK: void constexpr_expr(Arg<6> );

// Template argument is a constexpr variable declared inside a namespace.
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_ns(Arg<ns::C>) {}

// CHECK: void constexpr_ns(Arg<4> );

// The five declarations named constexpr_ns2 below are overloads that differ
// only in their parameter type.

// Template argument is a static constexpr data member of a struct.
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_ns2(Arg<ns::Foo::D>) {}

// CHECK: void constexpr_ns2(Arg<5> );

// Template argument is an enumerator of an unscoped enumeration; the expected
// spelling is the enumerator qualified by its namespace only.
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_ns2(Arg2<ns::non_class_enum::VAL_A>) {}

// CHECK: void constexpr_ns2(Arg2<ns::VAL_A> );

// Same as above, for an unscoped enumeration with a fixed underlying type.
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_ns2(Arg3<ns::non_class_enum_typed::VAL_C>) {}

// CHECK: void constexpr_ns2(Arg3<ns::VAL_C> );

// Template argument is an enumerator of a scoped enumeration; the expected
// spelling keeps the enumeration name as a qualifier.
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_ns2(Arg4<ns::class_enum::VAL_A>) {}

// CHECK: void constexpr_ns2(Arg4<ns::class_enum::VAL_A> );

// Same as above, for a scoped enumeration with a fixed underlying type.
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_ns2(Arg5<ns::class_enum_typed::VAL_C>) {}

// CHECK: void constexpr_ns2(Arg5<ns::class_enum_typed::VAL_C> );

// Template argument is a call to a constexpr function; the expected
// declaration contains the evaluated result (3 + 42 == 45).
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_call(Arg<ns::bar(B)>) {}

// CHECK: void constexpr_call(Arg<45> );