diff --git a/sycl/include/CL/sycl/detail/device_host.hpp b/sycl/include/CL/sycl/detail/device_host.hpp index 8f17c4c880b6e..237e9f6ea9a8c 100644 --- a/sycl/include/CL/sycl/detail/device_host.hpp +++ b/sycl/include/CL/sycl/detail/device_host.hpp @@ -24,6 +24,9 @@ class device_host : public device_impl { cl_device_id get() const override { throw invalid_object_error("This instance of device is a host instance"); } + cl_device_id &getHandleRef() override { + throw invalid_object_error("This instance of device is a host instance"); + } bool is_host() const override { return true; } diff --git a/sycl/include/CL/sycl/detail/device_impl.hpp b/sycl/include/CL/sycl/detail/device_impl.hpp index 1238ce058d976..0b3721908fee2 100644 --- a/sycl/include/CL/sycl/detail/device_impl.hpp +++ b/sycl/include/CL/sycl/detail/device_impl.hpp @@ -29,6 +29,12 @@ class device_impl { virtual cl_device_id get() const = 0; + // Returns underlying native device object (if any) w/o reference count + // modification. Caller must ensure the returned object lives on stack only. + // It can also be safely passed to the underlying native runtime API. + // Warning. Returned reference will be invalid if device_impl was destroyed. + virtual cl_device_id &getHandleRef() = 0; + virtual bool is_host() const = 0; virtual bool is_cpu() const = 0; diff --git a/sycl/include/CL/sycl/detail/device_opencl.hpp b/sycl/include/CL/sycl/detail/device_opencl.hpp index 31808fb3d3a9b..a31f41b005ec1 100644 --- a/sycl/include/CL/sycl/detail/device_opencl.hpp +++ b/sycl/include/CL/sycl/detail/device_opencl.hpp @@ -57,6 +57,10 @@ class device_opencl : public device_impl { return id; } + cl_device_id &getHandleRef() override{ + return id; + } + bool is_host() const override { return false; } bool is_cpu() const override { return (type == CL_DEVICE_TYPE_CPU); } diff --git a/sycl/include/CL/sycl/detail/program_impl.hpp b/sycl/include/CL/sycl/detail/program_impl.hpp index b1d878d7f6908..f79e758aa4fa7 100644 --- a/sycl/include/CL/sycl/detail/program_impl.hpp +++ b/sycl/include/CL/sycl/detail/program_impl.hpp @@ -51,15 +51,23 @@ class program_impl { } Context = ProgramList[0]->Context; Devices = ProgramList[0]->Devices; + std::vector DevicesSorted; + if (!is_host()) { + DevicesSorted = sort_devices_by_cl_device_id(Devices); + } for (const auto &Prg : ProgramList) { Prg->throw_if_state_is_not(program_state::compiled); if (Prg->Context != Context) { throw invalid_object_error( "Not all programs are associated with the same context"); } - if (Prg->Devices != Devices) { - throw invalid_object_error( - "Not all programs are associated with the same devices"); + if (!is_host()) { + std::vector PrgDevicesSorted = + sort_devices_by_cl_device_id(Prg->Devices); + if (PrgDevicesSorted != DevicesSorted) { + throw invalid_object_error( + "Not all programs are associated with the same devices"); + } } } @@ -92,7 +100,20 @@ class program_impl { CHECK_OCL_CODE(clGetProgramInfo(ClProgram, CL_PROGRAM_DEVICES, sizeof(cl_device_id) * NumDevices, ClDevices.data(), nullptr)); - Devices = vector_class(ClDevices.begin(), ClDevices.end()); + vector_class SyclContextDevices = Context.get_devices(); + + // Keep only the subset of the devices (associated with context) that + // were actually used to create the program. + // This is possible when clCreateProgramWithBinary is used. + auto NewEnd = std::remove_if( + SyclContextDevices.begin(), SyclContextDevices.end(), + [&ClDevices](const sycl::device &Dev) { + return ClDevices.end() == + std::find(ClDevices.begin(), ClDevices.end(), + detail::getSyclObjImpl(Dev)->getHandleRef()); + }); + SyclContextDevices.erase(NewEnd, SyclContextDevices.end()); + Devices = SyclContextDevices; // TODO check build for each device instead cl_program_binary_type BinaryType; CHECK_OCL_CODE(clGetProgramBuildInfo( @@ -371,6 +392,16 @@ class program_impl { return ClKernel; } + std::vector + sort_devices_by_cl_device_id(vector_class Devices) { + std::sort(Devices.begin(), Devices.end(), + [](const device &id1, const device &id2) { + return (detail::getSyclObjImpl(id1)->getHandleRef() < + detail::getSyclObjImpl(id2)->getHandleRef()); + }); + return Devices; + } + void throw_if_state_is(program_state State) const { if (this->State == State) { throw invalid_object_error("Invalid program state");