diff --git a/sycl/test-e2e/MultiDevice/set_arg_pointer.cpp b/sycl/test-e2e/MultiDevice/set_arg_pointer.cpp new file mode 100644 index 0000000000000..39b7ceb79488e --- /dev/null +++ b/sycl/test-e2e/MultiDevice/set_arg_pointer.cpp @@ -0,0 +1,63 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: level_zero_v2_adapter +// UNSUPPORTED-TRACKER: CMPLRLLVM-67039 + +// Test that usm device pointer can be used in a kernel compiled for a context +// with multiple devices. + +#include +#include +#include +#include +#include +#include + +using namespace sycl; + +class AddIdxKernel; + +int main() { + sycl::platform plt; + std::vector devices = plt.get_devices(); + if (devices.size() < 2) { + std::cout << "Need at least 2 GPU devices for this test.\n"; + return 0; + } + + std::vector ctx_devices{devices[0], devices[1]}; + sycl::context ctx(ctx_devices); + + constexpr size_t N = 16; + std::vector> results(ctx_devices.size(), + std::vector(N, 0)); + + // Create a kernel bundle compiled for both devices in the context + auto kb = sycl::get_kernel_bundle(ctx); + + // For each device, create a queue and run a kernel using device USM + for (size_t i = 0; i < ctx_devices.size(); ++i) { + sycl::queue q(ctx, ctx_devices[i]); + int *data = sycl::malloc_device(N, q); + q.fill(data, 1, N).wait(); + q.submit([&](sycl::handler &h) { + h.use_kernel_bundle(kb); + h.parallel_for( + sycl::range<1>(N), [=](sycl::id<1> idx) { data[idx] += idx[0]; }); + }).wait(); + q.memcpy(results[i].data(), data, N * sizeof(int)).wait(); + sycl::free(data, q); + } + + for (size_t i = 0; i < ctx_devices.size(); ++i) { + std::cout << "Device " << i << " results: "; + for (size_t j = 0; j < N; ++j) { + if (results[i][j] != 1 + static_cast(j)) { + return -1; + } + std::cout << results[i][j] << " "; + } + } + return 0; +} diff --git a/unified-runtime/source/adapters/level_zero/command_buffer.cpp b/unified-runtime/source/adapters/level_zero/command_buffer.cpp index 687c905417d8b..25d45f7232636 100644 --- a/unified-runtime/source/adapters/level_zero/command_buffer.cpp +++ b/unified-runtime/source/adapters/level_zero/command_buffer.cpp @@ -1004,12 +1004,16 @@ ur_result_t setKernelPendingArguments( ze_kernel_handle_t ZeKernel) { // If there are any pending arguments set them now. for (auto &Arg : PendingArguments) { - // The ArgValue may be a NULL pointer in which case a NULL value is used for - // the kernel argument declared as a pointer to global or constant memory. char **ZeHandlePtr = nullptr; - if (Arg.Value) { - UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode, Device, - nullptr, 0u)); + if (auto MemObjPtr = std::get_if(&Arg.Value)) { + ur_mem_handle_t MemObj = *MemObjPtr; + if (MemObj) { + UR_CALL(MemObj->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode, Device, + nullptr, 0u)); + } + } else { + auto Ptr = const_cast(&std::get(Arg.Value)); + ZeHandlePtr = reinterpret_cast(Ptr); } ZE2UR_CALL(zeKernelSetArgumentValue, (ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr)); diff --git a/unified-runtime/source/adapters/level_zero/helpers/kernel_helpers.cpp b/unified-runtime/source/adapters/level_zero/helpers/kernel_helpers.cpp index 97aac29a84fbe..a8c75e41e44da 100644 --- a/unified-runtime/source/adapters/level_zero/helpers/kernel_helpers.cpp +++ b/unified-runtime/source/adapters/level_zero/helpers/kernel_helpers.cpp @@ -156,3 +156,25 @@ ur_result_t calculateKernelWorkDimensions( return UR_RESULT_SUCCESS; } + +ur_result_t setArgValueOnZeKernel(ze_kernel_handle_t hZeKernel, + uint32_t argIndex, size_t argSize, + const void *pArgValue) { + // OpenCL: "the arg_value pointer can be NULL or point to a NULL value + // in which case a NULL value will be used as the value for the argument + // declared as a pointer to global or constant memory in the kernel" + // + // We don't know the type of the argument but it seems that the only time + // SYCL RT would send a pointer to NULL in 'arg_value' is when the argument + // is a NULL pointer. Treat a pointer to NULL in 'arg_value' as a NULL. + if (argSize == sizeof(void *) && pArgValue && + *(void **)(const_cast(pArgValue)) == nullptr) { + pArgValue = nullptr; + } + + ze_result_t ZeResult = ZE_CALL_NOCHECK( + zeKernelSetArgumentValue, (hZeKernel, argIndex, argSize, pArgValue)); + if (ZeResult == ZE_RESULT_ERROR_INVALID_ARGUMENT) + return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE; + return ze2urResult(ZeResult); +} diff --git a/unified-runtime/source/adapters/level_zero/helpers/kernel_helpers.hpp b/unified-runtime/source/adapters/level_zero/helpers/kernel_helpers.hpp index 5dcf0c9123045..cc090bed8d7d9 100644 --- a/unified-runtime/source/adapters/level_zero/helpers/kernel_helpers.hpp +++ b/unified-runtime/source/adapters/level_zero/helpers/kernel_helpers.hpp @@ -71,3 +71,15 @@ inline void postSubmit(ze_kernel_handle_t hZeKernel, zeKernelSetGlobalOffsetExp(hZeKernel, 0, 0, 0); } } + +/** + * Helper to set kernel argument for ze_kernel_handle_t. + * @param[in] hZeKernel The handle to the Level-Zero kernel. + * @param[in] argIndex The index of the argument to set. + * @param[in] argSize The size of the argument to set. + * @param[in] pArgValue The pointer to the argument value. + * @return UR_RESULT_SUCCESS or an error code on failure + */ +ur_result_t setArgValueOnZeKernel(ze_kernel_handle_t hZeKernel, + uint32_t argIndex, size_t argSize, + const void *pArgValue); diff --git a/unified-runtime/source/adapters/level_zero/kernel.cpp b/unified-runtime/source/adapters/level_zero/kernel.cpp index 45b7b087cece5..bcac9cb04c320 100644 --- a/unified-runtime/source/adapters/level_zero/kernel.cpp +++ b/unified-runtime/source/adapters/level_zero/kernel.cpp @@ -125,16 +125,22 @@ ur_result_t urEnqueueKernelLaunch( // If there are any pending arguments set them now. for (auto &Arg : Kernel->PendingArguments) { - // The ArgValue may be a NULL pointer in which case a NULL value is used for - // the kernel argument declared as a pointer to global or constant memory. + // The Arg.Value can be either a ur_mem_handle_t or a raw pointer + // (const void*). Resolve per-device: for mem handles obtain the device + // specific handle, otherwise pass the raw pointer value. char **ZeHandlePtr = nullptr; - if (Arg.Value) { - UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode, - Queue->Device, EventWaitList, - NumEventsInWaitList)); + if (auto MemObjPtr = std::get_if(&Arg.Value)) { + ur_mem_handle_t MemObj = *MemObjPtr; + if (MemObj) { + UR_CALL(MemObj->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode, + Queue->Device, EventWaitList, + NumEventsInWaitList)); + } + } else { + auto Ptr = const_cast(&std::get(Arg.Value)); + ZeHandlePtr = reinterpret_cast(Ptr); } - ZE2UR_CALL(zeKernelSetArgumentValue, - (ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr)); + UR_CALL(setArgValueOnZeKernel(ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr)); } Kernel->PendingArguments.clear(); @@ -422,41 +428,21 @@ ur_result_t urKernelSetArgValue( UR_ASSERT(Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE); - // OpenCL: "the arg_value pointer can be NULL or point to a NULL value - // in which case a NULL value will be used as the value for the argument - // declared as a pointer to global or constant memory in the kernel" - // - // We don't know the type of the argument but it seems that the only time - // SYCL RT would send a pointer to NULL in 'arg_value' is when the argument - // is a NULL pointer. Treat a pointer to NULL in 'arg_value' as a NULL. - if (ArgSize == sizeof(void *) && PArgValue && - *(void **)(const_cast(PArgValue)) == nullptr) { - PArgValue = nullptr; - } - if (ArgIndex > Kernel->ZeKernelProperties->numKernelArgs - 1) { return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX; } std::scoped_lock Guard(Kernel->Mutex); - ze_result_t ZeResult = ZE_RESULT_SUCCESS; if (Kernel->ZeKernelMap.empty()) { auto ZeKernel = Kernel->ZeKernel; - ZeResult = ZE_CALL_NOCHECK(zeKernelSetArgumentValue, - (ZeKernel, ArgIndex, ArgSize, PArgValue)); + UR_CALL(setArgValueOnZeKernel(ZeKernel, ArgIndex, ArgSize, PArgValue)) } else { for (auto It : Kernel->ZeKernelMap) { auto ZeKernel = It.second; - ZeResult = ZE_CALL_NOCHECK(zeKernelSetArgumentValue, - (ZeKernel, ArgIndex, ArgSize, PArgValue)); + UR_CALL(setArgValueOnZeKernel(ZeKernel, ArgIndex, ArgSize, PArgValue)) } } - - if (ZeResult == ZE_RESULT_ERROR_INVALID_ARGUMENT) { - return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE; - } - - return ze2urResult(ZeResult); + return UR_RESULT_SUCCESS; } ur_result_t urKernelSetArgLocal( @@ -732,6 +718,23 @@ ur_result_t urKernelSetArgPointer( /// [in][optional] SVM pointer to memory location holding the argument /// value. If null then argument value is considered null. const void *ArgValue) { + UR_ASSERT(Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE); + { + std::scoped_lock Guard(Kernel->Mutex); + // In multi-device context instead of setting pointer arguments immediately + // across all device kernels, store them as pending so they can be resolved + // per-device at enqueue time. This ensures the correct handle is used for + // the device of the queue. + if (Kernel->Program->Context->getDevices().size() > 1) { + if (ArgIndex > Kernel->ZeKernelProperties->numKernelArgs - 1) { + return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX; + } + Kernel->PendingArguments.push_back({ArgIndex, sizeof(const void *), + ArgValue, ur_mem_handle_t_::unknown}); + + return UR_RESULT_SUCCESS; + } + } // KernelSetArgValue is expecting a pointer to the argument UR_CALL(ur::level_zero::urKernelSetArgValue( diff --git a/unified-runtime/source/adapters/level_zero/kernel.hpp b/unified-runtime/source/adapters/level_zero/kernel.hpp index 131dba270c05d..38e2e43e366b6 100644 --- a/unified-runtime/source/adapters/level_zero/kernel.hpp +++ b/unified-runtime/source/adapters/level_zero/kernel.hpp @@ -10,6 +10,7 @@ #pragma once #include +#include #include "common.hpp" #include "common/ur_ref_count.hpp" @@ -97,8 +98,10 @@ struct ur_kernel_handle_t_ : ur_object { struct ArgumentInfo { uint32_t Index; size_t Size; - // const ur_mem_handle_t_ *Value; - ur_mem_handle_t_ *Value; + // Value may be either a memory object or a raw pointer value (for pointer + // arguments). Resolve at enqueue time per-device to ensure correct handle + // is used for that device. + std::variant Value; ur_mem_handle_t_::access_mode_t AccessMode{ur_mem_handle_t_::unknown}; }; // Arguments that still need to be set (with zeKernelSetArgumentValue)