Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions sycl/test-e2e/MultiDevice/set_arg_pointer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

// UNSUPPORTED: level_zero_v2_adapter
// UNSUPPORTED-TRACKER: CMPLRLLVM-67039

// Test that a USM device pointer can be used in a kernel compiled for a
// context with multiple devices.

#include <iostream>
#include <sycl/detail/core.hpp>
#include <sycl/kernel_bundle.hpp>
#include <sycl/platform.hpp>
#include <sycl/usm.hpp>
#include <vector>

using namespace sycl;

// Kernel name type used to label the parallel_for launched in main().
class AddIdxKernel;

// Runs AddIdxKernel once per device of a two-device context. Each run uses a
// device USM allocation made through that device's queue, fills it with 1,
// adds the element index in the kernel and copies the data back to the host.
// Afterwards every element is expected to equal 1 + its index.
//
// Returns 0 on success, or when fewer than two devices are available (the
// test is then effectively skipped), and -1 on an allocation failure or a
// result mismatch.
int main() {
  sycl::platform plt;
  std::vector<sycl::device> devices = plt.get_devices();
  if (devices.size() < 2) {
    std::cout << "Need at least 2 GPU devices for this test.\n";
    return 0;
  }

  std::vector<sycl::device> ctx_devices{devices[0], devices[1]};
  sycl::context ctx(ctx_devices);

  constexpr size_t N = 16;
  std::vector<std::vector<int>> results(ctx_devices.size(),
                                        std::vector<int>(N, 0));

  // Create a kernel bundle compiled for both devices in the context
  auto kb = sycl::get_kernel_bundle<sycl::bundle_state::executable>(ctx);

  // For each device, create a queue and run a kernel using device USM
  for (size_t i = 0; i < ctx_devices.size(); ++i) {
    sycl::queue q(ctx, ctx_devices[i]);
    int *data = sycl::malloc_device<int>(N, q);
    // malloc_device reports failure by returning a null pointer; bail out
    // before the pointer is handed to fill/kernel/memcpy.
    if (data == nullptr) {
      std::cerr << "Device " << i << ": malloc_device failed\n";
      return -1;
    }
    q.fill(data, 1, N).wait();
    q.submit([&](sycl::handler &h) {
       h.use_kernel_bundle(kb);
       h.parallel_for<AddIdxKernel>(sycl::range<1>(N), [=](sycl::id<1> idx) {
         // idx[0] is a size_t; make the narrowing to int explicit.
         data[idx] += static_cast<int>(idx[0]);
       });
     }).wait();
    q.memcpy(results[i].data(), data, N * sizeof(int)).wait();
    sycl::free(data, q);
  }

  for (size_t i = 0; i < ctx_devices.size(); ++i) {
    std::cout << "Device " << i << " results: ";
    for (size_t j = 0; j < N; ++j) {
      const int expected = 1 + static_cast<int>(j);
      if (results[i][j] != expected) {
        // Terminate the partially printed line, then say what went wrong so
        // a failing run is diagnosable from the log alone.
        std::cout << "\n";
        std::cerr << "Device " << i << ": mismatch at index " << j << ": got "
                  << results[i][j] << ", expected " << expected << "\n";
        return -1;
      }
      std::cout << results[i][j] << " ";
    }
    // Keep each device's results on its own line.
    std::cout << "\n";
  }
  return 0;
}
14 changes: 9 additions & 5 deletions unified-runtime/source/adapters/level_zero/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1004,12 +1004,16 @@ ur_result_t setKernelPendingArguments(
ze_kernel_handle_t ZeKernel) {
// If there are any pending arguments set them now.
for (auto &Arg : PendingArguments) {
// The ArgValue may be a NULL pointer in which case a NULL value is used for
// the kernel argument declared as a pointer to global or constant memory.
char **ZeHandlePtr = nullptr;
if (Arg.Value) {
UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode, Device,
nullptr, 0u));
if (auto MemObjPtr = std::get_if<ur_mem_handle_t>(&Arg.Value)) {
ur_mem_handle_t MemObj = *MemObjPtr;
if (MemObj) {
UR_CALL(MemObj->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode, Device,
nullptr, 0u));
}
} else {
auto Ptr = const_cast<void **>(&std::get<const void *>(Arg.Value));
ZeHandlePtr = reinterpret_cast<char **>(Ptr);
}
ZE2UR_CALL(zeKernelSetArgumentValue,
(ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,25 @@ ur_result_t calculateKernelWorkDimensions(

return UR_RESULT_SUCCESS;
}

// Sets a single argument on a Level Zero kernel and converts the Level Zero
// status into a UR result.
ur_result_t setArgValueOnZeKernel(ze_kernel_handle_t hZeKernel,
                                  uint32_t argIndex, size_t argSize,
                                  const void *pArgValue) {
  // OpenCL allows 'arg_value' to be NULL or to point to a NULL value; in both
  // cases the argument (a pointer to global or constant memory) receives NULL.
  //
  // The argument type is not known here. The assumption is that the SYCL RT
  // only ever hands over a pointer to NULL when the argument itself is a NULL
  // pointer, so a pointer-sized value that reads as NULL is collapsed to NULL.
  const bool IsPointerSized = argSize == sizeof(void *);
  if (IsPointerSized && pArgValue != nullptr) {
    void *const Pointee = *static_cast<void *const *>(pArgValue);
    if (Pointee == nullptr)
      pArgValue = nullptr;
  }

  const ze_result_t Status = ZE_CALL_NOCHECK(
      zeKernelSetArgumentValue, (hZeKernel, argIndex, argSize, pArgValue));

  // An invalid-argument status from Level Zero is reported to the caller as a
  // bad argument size; every other status goes through the generic mapping.
  return Status == ZE_RESULT_ERROR_INVALID_ARGUMENT
             ? UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE
             : ze2urResult(Status);
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,15 @@ inline void postSubmit(ze_kernel_handle_t hZeKernel,
zeKernelSetGlobalOffsetExp(hZeKernel, 0, 0, 0);
}
}

/**
 * Helper to set kernel argument for ze_kernel_handle_t.
 *
 * When argSize equals sizeof(void *) and pArgValue points to a null pointer,
 * the argument is set as if pArgValue itself were null.
 *
 * @param[in] hZeKernel The handle to the Level-Zero kernel.
 * @param[in] argIndex The index of the argument to set.
 * @param[in] argSize The size of the argument to set.
 * @param[in] pArgValue The pointer to the argument value. May be null.
 * @return UR_RESULT_SUCCESS or an error code on failure. A Level-Zero
 * ZE_RESULT_ERROR_INVALID_ARGUMENT result is reported as
 * UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE; any other result is
 * converted with ze2urResult.
 */
ur_result_t setArgValueOnZeKernel(ze_kernel_handle_t hZeKernel,
uint32_t argIndex, size_t argSize,
const void *pArgValue);
65 changes: 34 additions & 31 deletions unified-runtime/source/adapters/level_zero/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,22 @@ ur_result_t urEnqueueKernelLaunch(

// If there are any pending arguments set them now.
for (auto &Arg : Kernel->PendingArguments) {
// The ArgValue may be a NULL pointer in which case a NULL value is used for
// the kernel argument declared as a pointer to global or constant memory.
// The Arg.Value can be either a ur_mem_handle_t or a raw pointer
// (const void*). Resolve per-device: for mem handles obtain the device
// specific handle, otherwise pass the raw pointer value.
char **ZeHandlePtr = nullptr;
if (Arg.Value) {
UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode,
Queue->Device, EventWaitList,
NumEventsInWaitList));
if (auto MemObjPtr = std::get_if<ur_mem_handle_t>(&Arg.Value)) {
ur_mem_handle_t MemObj = *MemObjPtr;
if (MemObj) {
UR_CALL(MemObj->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode,
Queue->Device, EventWaitList,
NumEventsInWaitList));
}
} else {
auto Ptr = const_cast<void **>(&std::get<const void *>(Arg.Value));
ZeHandlePtr = reinterpret_cast<char **>(Ptr);
}
ZE2UR_CALL(zeKernelSetArgumentValue,
(ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
UR_CALL(setArgValueOnZeKernel(ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
}
Kernel->PendingArguments.clear();

Expand Down Expand Up @@ -422,41 +428,21 @@ ur_result_t urKernelSetArgValue(

UR_ASSERT(Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);

// OpenCL: "the arg_value pointer can be NULL or point to a NULL value
// in which case a NULL value will be used as the value for the argument
// declared as a pointer to global or constant memory in the kernel"
//
// We don't know the type of the argument but it seems that the only time
// SYCL RT would send a pointer to NULL in 'arg_value' is when the argument
// is a NULL pointer. Treat a pointer to NULL in 'arg_value' as a NULL.
if (ArgSize == sizeof(void *) && PArgValue &&
*(void **)(const_cast<void *>(PArgValue)) == nullptr) {
PArgValue = nullptr;
}

if (ArgIndex > Kernel->ZeKernelProperties->numKernelArgs - 1) {
return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
}

std::scoped_lock<ur_shared_mutex> Guard(Kernel->Mutex);
ze_result_t ZeResult = ZE_RESULT_SUCCESS;
if (Kernel->ZeKernelMap.empty()) {
auto ZeKernel = Kernel->ZeKernel;
ZeResult = ZE_CALL_NOCHECK(zeKernelSetArgumentValue,
(ZeKernel, ArgIndex, ArgSize, PArgValue));
UR_CALL(setArgValueOnZeKernel(ZeKernel, ArgIndex, ArgSize, PArgValue))
} else {
for (auto It : Kernel->ZeKernelMap) {
auto ZeKernel = It.second;
ZeResult = ZE_CALL_NOCHECK(zeKernelSetArgumentValue,
(ZeKernel, ArgIndex, ArgSize, PArgValue));
UR_CALL(setArgValueOnZeKernel(ZeKernel, ArgIndex, ArgSize, PArgValue))
}
}

if (ZeResult == ZE_RESULT_ERROR_INVALID_ARGUMENT) {
return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE;
}

return ze2urResult(ZeResult);
return UR_RESULT_SUCCESS;
}

ur_result_t urKernelSetArgLocal(
Expand Down Expand Up @@ -732,6 +718,23 @@ ur_result_t urKernelSetArgPointer(
/// [in][optional] SVM pointer to memory location holding the argument
/// value. If null then argument value is considered null.
const void *ArgValue) {
UR_ASSERT(Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
{
std::scoped_lock<ur_shared_mutex> Guard(Kernel->Mutex);
// In multi-device context instead of setting pointer arguments immediately
// across all device kernels, store them as pending so they can be resolved
// per-device at enqueue time. This ensures the correct handle is used for
// the device of the queue.
if (Kernel->Program->Context->getDevices().size() > 1) {
if (ArgIndex > Kernel->ZeKernelProperties->numKernelArgs - 1) {
return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
}
Kernel->PendingArguments.push_back({ArgIndex, sizeof(const void *),
ArgValue, ur_mem_handle_t_::unknown});

return UR_RESULT_SUCCESS;
}
}

// KernelSetArgValue is expecting a pointer to the argument
UR_CALL(ur::level_zero::urKernelSetArgValue(
Expand Down
7 changes: 5 additions & 2 deletions unified-runtime/source/adapters/level_zero/kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#pragma once

#include <unordered_set>
#include <variant>

#include "common.hpp"
#include "common/ur_ref_count.hpp"
Expand Down Expand Up @@ -97,8 +98,10 @@ struct ur_kernel_handle_t_ : ur_object {
// Describes one kernel argument whose Level Zero value is set later rather
// than at the time the argument is recorded.
// NOTE(review): two declarations of `Value` appear in this view (a
// `ur_mem_handle_t_ *` and a `std::variant`); presumably the former is the
// pre-change line of the diff and only the variant survives - confirm.
struct ArgumentInfo {
// Index of the kernel argument, passed on to zeKernelSetArgumentValue.
uint32_t Index;
// Size in bytes of the argument, passed on to zeKernelSetArgumentValue.
size_t Size;
// const ur_mem_handle_t_ *Value;
ur_mem_handle_t_ *Value;
// Value may be either a memory object or a raw pointer value (for pointer
// arguments). Resolve at enqueue time per-device to ensure correct handle
// is used for that device.
std::variant<ur_mem_handle_t, const void *> Value;
// Access mode handed to getZeHandlePtr when Value holds a memory object;
// defaults to `unknown`.
ur_mem_handle_t_::access_mode_t AccessMode{ur_mem_handle_t_::unknown};
};
// Arguments that still need to be set (with zeKernelSetArgumentValue)
Expand Down