diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index f983582ed2085..8b41b68a5798e 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -18,6 +18,7 @@ #include "clang/AST/SYCLKernelInfo.h" #include "clang/AST/StmtSYCL.h" #include "clang/AST/TemplateArgumentVisitor.h" +#include "clang/AST/Type.h" #include "clang/AST/TypeOrdering.h" #include "clang/AST/TypeVisitor.h" #include "clang/Analysis/CallGraph.h" @@ -6738,9 +6739,9 @@ class FreeFunctionPrinter { /// returned string Example: /// \code /// template - /// void foo(T1 a, T2 b); + /// void foo(T1 a, int b, T2 c); /// \endcode - /// returns string "T1 a, T2 b" + /// returns string "T1, int, T2" std::string getTemplatedParamList(const llvm::ArrayRef Parameters, PrintingPolicy Policy) { @@ -6748,13 +6749,65 @@ class FreeFunctionPrinter { llvm::SmallString<128> ParamList; llvm::raw_svector_ostream ParmListOstream{ParamList}; Policy.SuppressTagKeyword = true; + for (ParmVarDecl *Param : Parameters) { if (FirstParam) FirstParam = false; else ParmListOstream << ", "; - ParmListOstream << Param->getType().getAsString(Policy); - ParmListOstream << " " << Param->getNameAsString(); + + // There are cases when we can't directly use neither the original + // argument type, nor its canonical version. An example would be: + // template + // void kernel(sycl::accessor); + // template void kernel(sycl::accessor); + // Accessor has multiple non-type template arguments with default values + // and non-qualified type will not include necessary namespaces for all + // of them. Qualified type will have that information, but all references + // to T will be replaced to something like type-argument-0 + // What we do instead is we iterate template arguments of both versions + // of a type in sync and take elements from one or another to get the best + // of both: proper references to template arguments of a kernel itself and + // fully-qualified names for enumerations. + // + // Moral of the story: drop integration header ASAP (but that is blocked + // by support for 3rd-party host compilers, which is important). + QualType T = Param->getType(); + QualType CT = T.getCanonicalType(); + + auto *ET = dyn_cast(T.getTypePtr()); + if (!ET) { + ParmListOstream << T.getAsString(Policy); + continue; + } + + auto *TST = + dyn_cast(ET->getNamedType().getTypePtr()); + auto *CTST = dyn_cast(CT.getTypePtr()); + if (!TST || !CTST) { + ParmListOstream << T.getAsString(Policy); + continue; + } + + TemplateName TN = TST->getTemplateName(); + auto SpecArgs = TST->template_arguments(); + auto DeclArgs = CTST->template_arguments(); + + TN.getAsTemplateDecl()->printQualifiedName(ParmListOstream); + ParmListOstream << "<"; + + for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()), + SE = SpecArgs.size(); + I < E; ++I) { + if (I != 0) + ParmListOstream << ", "; + if (I < SE) // A specialized argument exists, use it + SpecArgs[I].print(Policy, ParmListOstream, false /* IncludeType */); + else // Print a canonical form of a default argument + DeclArgs[I].print(Policy, ParmListOstream, false /* IncludeType */); + } + + ParmListOstream << ">"; } return ParamList.str().str(); } diff --git a/clang/test/CodeGenSYCL/free-function-kernel-templated-arg-with-enum.cpp b/clang/test/CodeGenSYCL/free-function-kernel-templated-arg-with-enum.cpp new file mode 100644 index 0000000000000..bc45a922bd3b5 --- /dev/null +++ b/clang/test/CodeGenSYCL/free-function-kernel-templated-arg-with-enum.cpp @@ -0,0 +1,54 @@ +// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown -sycl-std=2020 -fsycl-int-header=%t.h %s +// RUN: FileCheck -input-file=%t.h %s +// +// The purpose of this test is to ensure that forward declarations of free +// function kernels are emitted properly. +// However, this test checks a specific scenario: +// - free function kernel is a function template +// - its argument is templated and has non-type template parameter (with default +// value) that is an enumeration defined within a namespace + +namespace ns { + +enum class enum_A { A, B, C }; + +template +class feature_A {}; + +namespace nested { +enum class enum_B { A, B, C }; + +template +struct feature_B {}; +} + +inline namespace nested_inline { +namespace nested2 { +enum class enum_C { A, B, C }; + +template +struct feature_C {}; +} +} // namespace nested_inline +} // namespace ns + +template +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void templated_on_A(ns::feature_A Arg) {} +template void templated_on_A(ns::feature_A); + +// CHECK: template void templated_on_A(ns::feature_A); + +template +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void templated_on_B(ns::nested::feature_B Arg) {} +template void templated_on_B(ns::nested::feature_B); + +// CHECK: template void templated_on_B(ns::nested::feature_B); + +template +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void templated_on_C(ns::nested2::feature_C Arg) {} +template void templated_on_C(ns::nested2::feature_C<42>); + +// CHECK: template void templated_on_C(ns::nested2::feature_C); diff --git a/clang/test/CodeGenSYCL/free_function_default_template_arguments.cpp b/clang/test/CodeGenSYCL/free_function_default_template_arguments.cpp index 2a9187d618b24..e660bd3070874 100644 --- a/clang/test/CodeGenSYCL/free_function_default_template_arguments.cpp +++ b/clang/test/CodeGenSYCL/free_function_default_template_arguments.cpp @@ -295,7 +295,7 @@ namespace Testing::Tests { // CHECK-NEXT: } // namespace _V1 // CHECK-NEXT: } // namespace sycl -// CHECK: template void templated(ns::Arg , T end); +// CHECK: template void templated(ns::Arg>, T); // CHECK-NEXT: static constexpr auto __sycl_shim3() { // CHECK-NEXT: return (void (*)(struct ns::Arg, int))templated; // CHECK-NEXT: } @@ -314,7 +314,7 @@ namespace Testing::Tests { // CHECK-NEXT: } // namespace _V1 // CHECK-NEXT: } // namespace sycl -// CHECK: template void templated2(ns::Arg , T end); +// CHECK: template void templated2(ns::Arg>, T); // CHECK-NEXT: static constexpr auto __sycl_shim4() { // CHECK-NEXT: return (void (*)(struct ns::Arg, int))templated2; // CHECK-NEXT: } @@ -333,7 +333,7 @@ namespace Testing::Tests { // CHECK-NEXT: } // namespace _V1 // CHECK-NEXT: } // namespace sycl -// CHECK: template void templated3(ns::Arg, int, int> , T end); +// CHECK: template void templated3(ns::Arg, int, int>, T); // CHECK-NEXT: static constexpr auto __sycl_shim5() { // CHECK-NEXT: return (void (*)(struct ns::Arg, int, int>, int))templated3; // CHECK-NEXT: } @@ -352,7 +352,7 @@ namespace Testing::Tests { // CHECK-NEXT: } // namespace _V1 // CHECK-NEXT: } // namespace sycl -// CHECK: template void templated3(ns::Arg, int, int> , T end); +// CHECK: template void templated3(ns::Arg, int, int>, T); // CHECK-NEXT: static constexpr auto __sycl_shim6() { // CHECK-NEXT: return (void (*)(struct ns::Arg, int, int>, float))templated3; // CHECK-NEXT: } @@ -400,7 +400,7 @@ namespace Testing::Tests { // CHECK-NEXT: } // namespace sycl // CHECK: namespace TestNamespace { -// CHECK-NEXT: template void templated(ns::Arg , T end); +// CHECK-NEXT: template void templated(ns::Arg>, T); // CHECK-NEXT: } // namespace TestNamespace // CHECK: static constexpr auto __sycl_shim8() { @@ -434,7 +434,7 @@ namespace Testing::Tests { // CHECK: namespace TestNamespace { // CHECK-NEXT: inline namespace _V1 { -// CHECK-NEXT: template void templated1(ns::Arg , T end); +// CHECK-NEXT: template void templated1(ns::Arg>, T); // CHECK-NEXT: } // inline namespace _V1 // CHECK-NEXT: } // namespace TestNamespace // CHECK: static constexpr auto __sycl_shim9() { @@ -468,7 +468,7 @@ namespace Testing::Tests { // CHECK: namespace TestNamespace { // CHECK-NEXT: inline namespace _V2 { -// CHECK-NEXT: template void templated1(ns::Arg , T end); +// CHECK-NEXT: template void templated1(ns::Arg>, T); // CHECK-NEXT: } // inline namespace _V2 // CHECK-NEXT: } // namespace TestNamespace // CHECK: static constexpr auto __sycl_shim10() { @@ -501,7 +501,7 @@ namespace Testing::Tests { // CHECK-NEXT: } // CHECK: namespace { -// CHECK-NEXT: template void templated(T start, T end); +// CHECK-NEXT: template void templated(T, T); // CHECK-NEXT: } // namespace // CHECK: static constexpr auto __sycl_shim11() { // CHECK-NEXT: return (void (*)(float, float))templated; @@ -533,7 +533,7 @@ namespace Testing::Tests { // CHECK-NEXT: } // CHECK: struct TestStruct; -// CHECK: template void templated(ns::Arg , T end); +// CHECK: template void templated(ns::Arg>, T); // CHECK-NEXT: static constexpr auto __sycl_shim12() { // CHECK-NEXT: return (void (*)(struct ns::Arg, struct TestStruct))templated; // CHECK-NEXT:} @@ -565,7 +565,7 @@ namespace Testing::Tests { // CHECK: class BaseClass; // CHECK: namespace { -// CHECK-NEXT: template void templated(T start, T end); +// CHECK-NEXT: template void templated(T, T); // CHECK-NEXT: } // namespace // CHECK: static constexpr auto __sycl_shim13() { // CHECK-NEXT: return (void (*)(class BaseClass, class BaseClass))templated; @@ -598,7 +598,7 @@ namespace Testing::Tests { // CHECK: class ChildOne; // CHECK: namespace { -// CHECK-NEXT: template void templated(T start, T end); +// CHECK-NEXT: template void templated(T, T); // CHECK-NEXT: } // namespace // CHECK: static constexpr auto __sycl_shim14() { // CHECK-NEXT: return (void (*)(class ChildOne, class ChildOne))templated; @@ -631,7 +631,7 @@ namespace Testing::Tests { // CHECK: class ChildTwo; // CHECK: namespace { -// CHECK-NEXT: template void templated(T start, T end); +// CHECK-NEXT: template void templated(T, T); // CHECK-NEXT: } // namespace // CHECK: static constexpr auto __sycl_shim15() { // CHECK-NEXT: return (void (*)(class ChildTwo, class ChildTwo))templated; @@ -664,7 +664,7 @@ namespace Testing::Tests { // CHECK: class ChildThree; // CHECK: namespace { -// CHECK-NEXT: template void templated(T start, T end); +// CHECK-NEXT: template void templated(T, T); // CHECK-NEXT: } // namespace // CHECK: static constexpr auto __sycl_shim16() { // CHECK-NEXT: return (void (*)(class ChildThree, class ChildThree))templated; @@ -699,7 +699,7 @@ namespace Testing::Tests { // CHECK-NEXT: template struct id; // CHECK-NEXT: }} // CHECK: namespace { -// CHECK-NEXT: template void templated(T start, T end); +// CHECK-NEXT: template void templated(T, T); // CHECK-NEXT: } // namespace // CHECK: static constexpr auto __sycl_shim17() { // CHECK-NEXT: return (void (*)(struct sycl::id<2>, struct sycl::id<2>))templated>; @@ -734,7 +734,7 @@ namespace Testing::Tests { // CHECK-NEXT: template struct range; // CHECK-NEXT: }} // CHECK: namespace { -// CHECK-NEXT: template void templated(T start, T end); +// CHECK-NEXT: template void templated(T, T); // CHECK-NEXT: } // namespace // CHECK: static constexpr auto __sycl_shim18() { // CHECK-NEXT: return (void (*)(struct sycl::range<3>, struct sycl::range<3>))templated>; @@ -766,7 +766,7 @@ namespace Testing::Tests { // CHECK-NEXT: } // CHECK: namespace { -// CHECK-NEXT: template void templated(T start, T end); +// CHECK-NEXT: template void templated(T, T); // CHECK-NEXT: } // namespace // CHECK: static constexpr auto __sycl_shim19() { // CHECK-NEXT: return (void (*)(int *, int *))templated; @@ -798,7 +798,7 @@ namespace Testing::Tests { // CHECK-NEXT: } // CHECK: namespace { -// CHECK-NEXT: template void templated(T start, T end); +// CHECK-NEXT: template void templated(T, T); // CHECK-NEXT: } // namespace // CHECK: static constexpr auto __sycl_shim20() { // CHECK-NEXT: return (void (*)(struct sycl::X, struct sycl::X))templated>; @@ -835,7 +835,7 @@ namespace Testing::Tests { // CHECK-NEXT: }}} // CHECK: namespace TestNamespace { // CHECK-NEXT: inline namespace _V1 { -// CHECK-NEXT: template void templated1(ns::Arg , T end); +// CHECK-NEXT: template void templated1(ns::Arg>, T); // CHECK-NEXT: } // inline namespace _V1 // CHECK-NEXT: } // namespace TestNamespace // CHECK: static constexpr auto __sycl_shim21() { @@ -867,7 +867,7 @@ namespace Testing::Tests { // CHECK-NEXT: }; // CHECK-NEXT: } -// CHECK: template void variadic_templated(Args... args); +// CHECK: template void variadic_templated(Args...); // CHECK-NEXT: static constexpr auto __sycl_shim22() { // CHECK-NEXT: return (void (*)(int, float, char))variadic_templated; // CHECK-NEXT: } @@ -897,7 +897,7 @@ namespace Testing::Tests { // CHECK-NEXT: }; // CHECK-NEXT: } -// CHECK: template void variadic_templated(Args... args); +// CHECK: template void variadic_templated(Args...); // CHECK-NEXT: static constexpr auto __sycl_shim23() { // CHECK-NEXT: return (void (*)(int, float, char, int))variadic_templated; // CHECK-NEXT: } @@ -927,7 +927,7 @@ namespace Testing::Tests { // CHECK-NEXT: }; // CHECK-NEXT: } -// CHECK: template void variadic_templated(Args... args); +// CHECK: template void variadic_templated(Args...); // CHECK-NEXT: static constexpr auto __sycl_shim24() { // CHECK-NEXT: return (void (*)(float, float))variadic_templated; // CHECK-NEXT: } @@ -957,7 +957,7 @@ namespace Testing::Tests { // CHECK-NEXT: }; // CHECK-NEXT: } -// CHECK: template void variadic_templated1(T b, Args... args); +// CHECK: template void variadic_templated1(T, Args...); // CHECK-NEXT: static constexpr auto __sycl_shim25() { // CHECK-NEXT: return (void (*)(float, char, char))variadic_templated1; // CHECK-NEXT: } @@ -987,7 +987,7 @@ namespace Testing::Tests { // CHECK-NEXT: }; // CHECK-NEXT: } -// CHECK: template void variadic_templated1(T b, Args... args); +// CHECK: template void variadic_templated1(T, Args...); // CHECK-NEXT: static constexpr auto __sycl_shim26() { // CHECK-NEXT: return (void (*)(int, float, char))variadic_templated1; // CHECK-NEXT: } @@ -1019,7 +1019,7 @@ namespace Testing::Tests { // CHECK: namespace Testing { // CHECK-NEXT: namespace Tests { -// CHECK-NEXT: template void variadic_templated(T b, Args... args); +// CHECK-NEXT: template void variadic_templated(T, Args...); // CHECK-NEXT: } // namespace Tests // CHECK-NEXT: } // namespace Testing // CHECK: static constexpr auto __sycl_shim27() { @@ -1053,7 +1053,7 @@ namespace Testing::Tests { // CHECK: namespace Testing { // CHECK-NEXT: namespace Tests { -// CHECK-NEXT: template void variadic_templated(T b, Args... args); +// CHECK-NEXT: template void variadic_templated(T, Args...); // CHECK-NEXT: } // namespace Tests // CHECK-NEXT: } // namespace Testing // CHECK: static constexpr auto __sycl_shim28() { diff --git a/clang/test/CodeGenSYCL/free_function_int_header.cpp b/clang/test/CodeGenSYCL/free_function_int_header.cpp index b05c299a2e478..4fe7a761e98c6 100644 --- a/clang/test/CodeGenSYCL/free_function_int_header.cpp +++ b/clang/test/CodeGenSYCL/free_function_int_header.cpp @@ -508,7 +508,7 @@ void ff_24(int arg) { // CHECK: Definition of _Z18__sycl_kernel_ff_3IiEvPT_S0_S0_ as a free function kernel // CHECK: Forward declarations of kernel and its argument types: -// CHECK: template void ff_3(T * ptr, T start, T end); +// CHECK: template void ff_3(T *, T, T); // CHECK-NEXT: static constexpr auto __sycl_shim3() { // CHECK-NEXT: return (void (*)(int *, int, int))ff_3; // CHECK-NEXT: } @@ -540,7 +540,7 @@ void ff_24(int arg) { // CHECK: Definition of _Z18__sycl_kernel_ff_3IfEvPT_S0_S0_ as a free function kernel // CHECK: Forward declarations of kernel and its argument types: -// CHECK: template void ff_3(T * ptr, T start, T end); +// CHECK: template void ff_3(T *, T, T); // CHECK-NEXT: static constexpr auto __sycl_shim4() { // CHECK-NEXT: return (void (*)(float *, float, float))ff_3; // CHECK-NEXT: } @@ -572,7 +572,7 @@ void ff_24(int arg) { // CHECK: Definition of _Z18__sycl_kernel_ff_3IdEvPT_S0_S0_ as a free function kernel // CHECK: Forward declarations of kernel and its argument types: -// CHECK: template void ff_3(T * ptr, T start, T end); +// CHECK: template void ff_3(T *, T, T); // CHECK: template <> void ff_3(double * ptr, double start, double end); // CHECK-NEXT: static constexpr auto __sycl_shim5() { // CHECK-NEXT: return (void (*)(double *, double, double))ff_3; @@ -641,7 +641,7 @@ void ff_24(int arg) { // CHECK: Definition of _Z18__sycl_kernel_ff_6I3Agg7DerivedEvT_T0_i as a free function kernel // CHECK: Forward declarations of kernel and its argument types: // CHECK: struct Derived; -// CHECK: template void ff_6(T1 S1, T2 S2, int end); +// CHECK: template void ff_6(T1, T2, int); // CHECK-NEXT: static constexpr auto __sycl_shim7() { // CHECK-NEXT: return (void (*)(struct Agg, struct Derived, int))ff_6; // CHECK-NEXT: } @@ -676,7 +676,7 @@ void ff_24(int arg) { // CHECK: Forward declarations of kernel and its argument types: // CHECK: template struct KArgWithPtrArray; // -// CHECK: template void ff_7(KArgWithPtrArray KArg); +// CHECK: template void ff_7(KArgWithPtrArray); // CHECK-NEXT: static constexpr auto __sycl_shim8() { // CHECK-NEXT: return (void (*)(struct KArgWithPtrArray<3>))ff_7<3>; // CHECK-NEXT: } @@ -1021,7 +1021,7 @@ void ff_24(int arg) { // CHECK: Forward declarations of kernel and its argument types: -// CHECK: template void ff_11(sycl::local_accessor lacc); +// CHECK: template void ff_11(sycl::local_accessor); // CHECK-NEXT: static constexpr auto __sycl_shim // CHECK-NEXT: return (void (*)(class sycl::local_accessor))ff_11; diff --git a/sycl/test-e2e/FreeFunctionKernels/accessor_as_kernel_parameter.cpp b/sycl/test-e2e/FreeFunctionKernels/accessor_as_kernel_parameter.cpp index 260c7d9b0203c..cd4bcfa532451 100644 --- a/sycl/test-e2e/FreeFunctionKernels/accessor_as_kernel_parameter.cpp +++ b/sycl/test-e2e/FreeFunctionKernels/accessor_as_kernel_parameter.cpp @@ -10,23 +10,16 @@ template SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::single_task_kernel)) -void globalScopeSingleFreeFunc( - sycl::accessor - Accessor, - int Value) { +void globalScopeSingleFreeFunc(sycl::accessor Accessor, int Value) { for (auto &Elem : Accessor) Elem = Value; } namespace ns { template SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel)) -void nsNdRangeFreeFunc(sycl::accessor - Accessor, - int Value) { +void nsNdRangeFreeFunc( + sycl::accessor Accessor, + int Value) { auto Item = syclext::this_work_item::get_nd_item().get_global_id(); Accessor[Item] = Value; } @@ -36,17 +29,10 @@ template SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel)) void ndRangeFreeFuncMultipleParameters( sycl::accessor + sycl::access::target::device> InputAAcc, - sycl::accessor - InputBAcc, - sycl::accessor - ResultAcc) { + sycl::accessor InputBAcc, + sycl::accessor ResultAcc) { auto Item = syclext::this_work_item::get_nd_item().get_global_id(); ResultAcc[Item] = InputAAcc[Item] + InputBAcc[Item]; } diff --git a/sycl/test-e2e/FreeFunctionKernels/template_specialization.cpp b/sycl/test-e2e/FreeFunctionKernels/template_specialization.cpp index 94e5ccc1f858c..65ce6daaa0d32 100644 --- a/sycl/test-e2e/FreeFunctionKernels/template_specialization.cpp +++ b/sycl/test-e2e/FreeFunctionKernels/template_specialization.cpp @@ -53,6 +53,13 @@ SYCL_EXT_ONEAPI_FUNCTION_PROPERTY( (ext::oneapi::experimental::nd_range_kernel<1>)) void sum1(T arg) {} +template <> +SYCL_EXT_ONEAPI_FUNCTION_PROPERTY( + (ext::oneapi::experimental::nd_range_kernel<1>)) +void sum1<3, sycl::accessor>(sycl::accessor arg) { + arg[0] = 42; +} + template <> SYCL_EXT_ONEAPI_FUNCTION_PROPERTY( (ext::oneapi::experimental::nd_range_kernel<1>)) @@ -137,6 +144,9 @@ void test_accessor() { h.set_args(acc); h.parallel_for(nd_range{{1}, {1}}, Kernel); }); + + auto acc = buf.get_host_access(); + assert(acc[0] == 42); } void test_shared() {