diff --git a/sycl/include/sycl/detail/common.hpp b/sycl/include/sycl/detail/common.hpp index 72249f860067c..8032d4d34c39e 100644 --- a/sycl/include/sycl/detail/common.hpp +++ b/sycl/include/sycl/detail/common.hpp @@ -226,6 +226,7 @@ namespace detail { // template // friend decltype(T::impl) detail::getSyclObjImpl(const T &SyclObject); template decltype(Obj::impl) getSyclObjImpl(const Obj &SyclObject) { + assert(SyclObject.impl && "every constructor should create an impl"); return SyclObject.impl; } diff --git a/sycl/include/sycl/handler.hpp b/sycl/include/sycl/handler.hpp index 5a27567ce90ed..bcb6cc2fdefda 100644 --- a/sycl/include/sycl/handler.hpp +++ b/sycl/include/sycl/handler.hpp @@ -1290,6 +1290,8 @@ class __SYCL_EXPORT handler { std::shared_ptr getOrInsertHandlerKernelBundle(bool Insert) const; + void setHandlerKernelBundle(kernel Kernel); + void setHandlerKernelBundle( const std::shared_ptr &NewKernelBundleImpPtr); @@ -1918,7 +1920,7 @@ class __SYCL_EXPORT handler { throwIfActionIsCreated(); verifyKernelInvoc(Kernel); // Ignore any set kernel bundles and use the one associated with the kernel - setHandlerKernelBundle(detail::getSyclObjImpl(Kernel.get_kernel_bundle())); + setHandlerKernelBundle(Kernel); // No need to check if range is out of INT_MAX limits as it's compile-time // known constant MNDRDesc.set(range<1>{1}); @@ -1991,7 +1993,7 @@ class __SYCL_EXPORT handler { void single_task(kernel Kernel, _KERNELFUNCPARAM(KernelFunc)) { throwIfActionIsCreated(); // Ignore any set kernel bundles and use the one associated with the kernel - setHandlerKernelBundle(detail::getSyclObjImpl(Kernel.get_kernel_bundle())); + setHandlerKernelBundle(Kernel); using NameT = typename detail::get_kernel_name_t::name; verifyUsedKernelBundle(detail::KernelInfo::getName()); @@ -2037,7 +2039,7 @@ class __SYCL_EXPORT handler { _KERNELFUNCPARAM(KernelFunc)) { throwIfActionIsCreated(); // Ignore any set kernel bundles and use the one associated with the kernel - setHandlerKernelBundle(detail::getSyclObjImpl(Kernel.get_kernel_bundle())); + setHandlerKernelBundle(Kernel); using NameT = typename detail::get_kernel_name_t::name; verifyUsedKernelBundle(detail::KernelInfo::getName()); @@ -2075,7 +2077,7 @@ class __SYCL_EXPORT handler { id WorkItemOffset, _KERNELFUNCPARAM(KernelFunc)) { throwIfActionIsCreated(); // Ignore any set kernel bundles and use the one associated with the kernel - setHandlerKernelBundle(detail::getSyclObjImpl(Kernel.get_kernel_bundle())); + setHandlerKernelBundle(Kernel); using NameT = typename detail::get_kernel_name_t::name; verifyUsedKernelBundle(detail::KernelInfo::getName()); @@ -2113,7 +2115,7 @@ class __SYCL_EXPORT handler { _KERNELFUNCPARAM(KernelFunc)) { throwIfActionIsCreated(); // Ignore any set kernel bundles and use the one associated with the kernel - setHandlerKernelBundle(detail::getSyclObjImpl(Kernel.get_kernel_bundle())); + setHandlerKernelBundle(Kernel); using NameT = typename detail::get_kernel_name_t::name; verifyUsedKernelBundle(detail::KernelInfo::getName()); @@ -2155,7 +2157,7 @@ class __SYCL_EXPORT handler { _KERNELFUNCPARAM(KernelFunc)) { throwIfActionIsCreated(); // Ignore any set kernel bundles and use the one associated with the kernel - setHandlerKernelBundle(detail::getSyclObjImpl(Kernel.get_kernel_bundle())); + setHandlerKernelBundle(Kernel); using NameT = typename detail::get_kernel_name_t::name; verifyUsedKernelBundle(detail::KernelInfo::getName()); @@ -2195,7 +2197,7 @@ class __SYCL_EXPORT handler { _KERNELFUNCPARAM(KernelFunc)) { throwIfActionIsCreated(); // Ignore any set kernel bundles and use the one associated with the kernel - setHandlerKernelBundle(detail::getSyclObjImpl(Kernel.get_kernel_bundle())); + setHandlerKernelBundle(Kernel); using NameT = typename detail::get_kernel_name_t::name; verifyUsedKernelBundle(detail::KernelInfo::getName()); diff --git a/sycl/source/handler.cpp b/sycl/source/handler.cpp index ca3b6d6b06ae8..5b751590f403c 100644 --- a/sycl/source/handler.cpp +++ b/sycl/source/handler.cpp @@ -82,6 +82,15 @@ void handler::setHandlerKernelBundle( MImpl->MKernelBundle = NewKernelBundleImpPtr; } +void handler::setHandlerKernelBundle(kernel Kernel) { + // Kernel may not have an associated kernel bundle if it is created from a + // program. As such, apply getSyclObjImpl directly on the kernel, i.e. not + // the other way around: getSyclObjImp(Kernel->get_kernel_bundle()). + std::shared_ptr KernelBundleImpl = + detail::getSyclObjImpl(Kernel)->get_kernel_bundle(); + setHandlerKernelBundle(KernelBundleImpl); +} + event handler::finalize() { // This block of code is needed only for reduction implementation. // It is harmless (does nothing) for everything else. diff --git a/sycl/test/abi/sycl_symbols_linux.dump b/sycl/test/abi/sycl_symbols_linux.dump index 0536819dede58..8dfeecc077c44 100644 --- a/sycl/test/abi/sycl_symbols_linux.dump +++ b/sycl/test/abi/sycl_symbols_linux.dump @@ -3957,6 +3957,7 @@ _ZN4sycl3_V17handler18extractArgsAndReqsEv _ZN4sycl3_V17handler20DisableRangeRoundingEv _ZN4sycl3_V17handler20associateWithHandlerEPNS0_6detail16AccessorBaseHostENS0_6access6targetE _ZN4sycl3_V17handler20setStateSpecConstSetEv +_ZN4sycl3_V17handler22setHandlerKernelBundleENS0_6kernelE _ZN4sycl3_V17handler22setHandlerKernelBundleERKSt10shared_ptrINS0_6detail18kernel_bundle_implEE _ZN4sycl3_V17handler22verifyUsedKernelBundleERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE _ZN4sycl3_V17handler24GetRangeRoundingSettingsERmS2_S2_