Skip to content

[SYCL] Fix get() method for non-opencl backends #3070

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions sycl/source/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,14 @@ context::context(const vector_class<device> &DeviceList,
PropList);
else {
const device &NonHostDevice = *NonHostDeviceIter;
const auto &NonHostPlatform = NonHostDevice.get_platform().get();
const auto &NonHostPlatform =
detail::getSyclObjImpl(NonHostDevice.get_platform())->getHandleRef();
if (std::any_of(DeviceList.begin(), DeviceList.end(),
[&](const device &CurrentDevice) {
return (CurrentDevice.is_host() ||
(CurrentDevice.get_platform().get() !=
NonHostPlatform));
return (
CurrentDevice.is_host() ||
(detail::getSyclObjImpl(CurrentDevice.get_platform())
->getHandleRef() != NonHostPlatform));
}))
throw invalid_parameter_error(
"Can't add devices across platforms to a single context.",
Expand Down
14 changes: 7 additions & 7 deletions sycl/source/detail/context_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,14 @@ context_impl::context_impl(RT::PiContext PiContext, async_handler AsyncHandler,
}

cl_context context_impl::get() const {
if (!MHostContext) {
// TODO catch an exception and put it to list of asynchronous exceptions
getPlugin().call<PiApiKind::piContextRetain>(MContext);
return pi::cast<cl_context>(MContext);
if (MHostContext || getPlugin().getBackend() != cl::sycl::backend::opencl) {
throw invalid_object_error(
"This instance of context doesn't support OpenCL interoperability.",
PI_INVALID_CONTEXT);
}
throw invalid_object_error(
"This instance of context doesn't support OpenCL interoperability.",
PI_INVALID_CONTEXT);
// TODO catch an exception and put it to list of asynchronous exceptions
getPlugin().call<PiApiKind::piContextRetain>(MContext);
return pi::cast<cl_context>(MContext);
}

bool context_impl::is_host() const { return MHostContext; }
Expand Down
13 changes: 6 additions & 7 deletions sycl/source/detail/device_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,13 @@ bool device_impl::is_affinity_supported(
}

cl_device_id device_impl::get() const {
if (MIsHostDevice)
throw invalid_object_error("This instance of device is a host instance",
PI_INVALID_DEVICE);

const detail::plugin &Plugin = getPlugin();

if (MIsHostDevice || getPlugin().getBackend() != cl::sycl::backend::opencl) {
throw invalid_object_error(
"This instance of device doesn't support OpenCL interoperability.",
PI_INVALID_DEVICE);
}
// TODO catch an exception and put it to list of asynchronous exceptions
Plugin.call<PiApiKind::piDeviceRetain>(MDevice);
getPlugin().call<PiApiKind::piDeviceRetain>(MDevice);
return pi::cast<cl_device_id>(getNative());
}

Expand Down
13 changes: 7 additions & 6 deletions sycl/source/detail/event_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@ extern xpti::trace_event_data_t *GSYCLGraphEvent;
bool event_impl::is_host() const { return MHostEvent || !MOpenCLInterop; }

cl_event event_impl::get() const {
if (MOpenCLInterop) {
getPlugin().call<PiApiKind::piEventRetain>(MEvent);
return pi::cast<cl_event>(MEvent);
if (!MOpenCLInterop ||
getPlugin().getBackend() != cl::sycl::backend::opencl) {
throw invalid_object_error(
"This instance of event doesn't support OpenCL interoperability.",
PI_INVALID_EVENT);
}
throw invalid_object_error(
"This instance of event doesn't support OpenCL interoperability.",
PI_INVALID_EVENT);
getPlugin().call<PiApiKind::piEventRetain>(MEvent);
return pi::cast<cl_event>(MEvent);
}

event_impl::~event_impl() {
Expand Down
8 changes: 5 additions & 3 deletions sycl/source/detail/kernel_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ class kernel_impl {
///
/// \return a valid cl_kernel instance
cl_kernel get() const {
if (is_host())
throw invalid_object_error("This instance of kernel is a host instance",
PI_INVALID_KERNEL);
if (is_host() || getPlugin().getBackend() != cl::sycl::backend::opencl) {
throw invalid_object_error(
"This instance of kernel doesn't support OpenCL interoperability.",
PI_INVALID_KERNEL);
}
getPlugin().call<PiApiKind::piKernelRetain>(MKernel);
return pi::cast<cl_kernel>(MKernel);
}
Expand Down
9 changes: 5 additions & 4 deletions sycl/source/detail/platform_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,11 @@ class platform_impl {

/// \return an instance of OpenCL cl_platform_id.
cl_platform_id get() const {
if (is_host())
throw invalid_object_error("This instance of platform is a host instance",
PI_INVALID_PLATFORM);

if (is_host() || getPlugin().getBackend() != cl::sycl::backend::opencl) {
throw invalid_object_error(
"This instance of platform doesn't support OpenCL interoperability.",
PI_INVALID_PLATFORM);
}
return pi::cast<cl_platform_id>(MPlatform);
}

Expand Down
10 changes: 5 additions & 5 deletions sycl/source/detail/program_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,12 @@ program_impl::~program_impl() {

cl_program program_impl::get() const {
throw_if_state_is(program_state::none);
if (is_host()) {
throw invalid_object_error("This instance of program is a host instance",
PI_INVALID_PROGRAM);
if (is_host() || getPlugin().getBackend() != cl::sycl::backend::opencl) {
throw invalid_object_error(
"This instance of program doesn't support OpenCL interoperability.",
PI_INVALID_PROGRAM);
}
const detail::plugin &Plugin = getPlugin();
Plugin.call<PiApiKind::piProgramRetain>(MProgram);
getPlugin().call<PiApiKind::piProgramRetain>(MProgram);
return pi::cast<cl_program>(MProgram);
}

Expand Down
12 changes: 6 additions & 6 deletions sycl/source/detail/queue_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,13 @@ class queue_impl {

/// \return an OpenCL interoperability queue handle.
cl_command_queue get() {
if (!MHostQueue) {
getPlugin().call<PiApiKind::piQueueRetain>(MQueues[0]);
return pi::cast<cl_command_queue>(MQueues[0]);
if (MHostQueue || getPlugin().getBackend() != cl::sycl::backend::opencl) {
throw invalid_object_error(
"This instance of queue doesn't support OpenCL interoperability",
PI_INVALID_QUEUE);
}
throw invalid_object_error(
"This instance of queue doesn't support OpenCL interoperability",
PI_INVALID_QUEUE);
getPlugin().call<PiApiKind::piQueueRetain>(MQueues[0]);
return pi::cast<cl_command_queue>(MQueues[0]);
}

/// \return an associated SYCL context.
Expand Down
2 changes: 0 additions & 2 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2010,8 +2010,6 @@ cl_int ExecCGCommand::enqueueImp() {
ExecInterop->MInteropTask->call(InteropHandler);
Plugin.call<PiApiKind::piEnqueueEventsWait>(MQueue->getHandleRef(), 0,
nullptr, &Event);
Plugin.call<PiApiKind::piQueueRelease>(
reinterpret_cast<pi_queue>(MQueue->get()));

return CL_SUCCESS;
}
Expand Down
12 changes: 9 additions & 3 deletions sycl/test/basic_tests/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ int main() {
for (const auto &plt : platform::get_platforms()) {
std::cout << "Platform " << i++
<< " is available: " << ((plt.is_host()) ? "host: " : "OpenCL: ")
<< std::hex << ((plt.is_host()) ? nullptr : plt.get())
<< std::hex
<< ((plt.is_host() ||
plt.get_backend() != cl::sycl::backend::opencl)
? nullptr
: plt.get())
<< std::endl;
}

Expand All @@ -34,7 +38,8 @@ int main() {
platform MovedPlatform(std::move(Platform));
assert(hash == hash_class<platform>()(MovedPlatform));
assert(platformA.is_host() == MovedPlatform.is_host());
if (!platformA.is_host()) {
if (!platformA.is_host() &&
platformA.get_backend() == cl::sycl::backend::opencl) {
assert(MovedPlatform.get() != nullptr);
}
}
Expand All @@ -46,7 +51,8 @@ int main() {
WillMovedPlatform = std::move(Platform);
assert(hash == hash_class<platform>()(WillMovedPlatform));
assert(platformA.is_host() == WillMovedPlatform.is_host());
if (!platformA.is_host()) {
if (!platformA.is_host() &&
platformA.get_backend() == cl::sycl::backend::opencl) {
assert(WillMovedPlatform.get() != nullptr);
}
}
Expand Down
7 changes: 5 additions & 2 deletions sycl/unittests/scheduler/CommandsWaitForEvents.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ pi_result getEventInfoFunc(pi_event Event, pi_event_info PName, size_t PVSize,

if (Event == TestContext->EventCtx1)
*reinterpret_cast<pi_context *>(PV) =
reinterpret_cast<pi_context>(TestContext->Ctx1->get());
reinterpret_cast<pi_context>(TestContext->Ctx1->getHandleRef());
else if (Event == TestContext->EventCtx2)
*reinterpret_cast<pi_context *>(PV) =
reinterpret_cast<pi_context>(TestContext->Ctx2->get());
reinterpret_cast<pi_context>(TestContext->Ctx2->getHandleRef());

return PI_SUCCESS;
}
Expand Down Expand Up @@ -109,4 +109,7 @@ TEST_F(SchedulerTest, CommandsWaitForEvents) {
ASSERT_TRUE(TestContext->EventCtx1WasWaited &&
TestContext->EventCtx2WasWaited)
<< "Not all events were waited for";
delete TestContext.release(); // explicitly delete here is important for CUDA
// BE to ensure that cuda driver is still in
// memory while cuda objects are being freed.
}