Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions sycl/include/sycl/ext/oneapi/get_kernel_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ template <typename KernelName, typename Param>
typename sycl::detail::is_kernel_device_specific_info_desc<Param>::return_type
get_kernel_info(const context &Ctx, const device &Dev) {
auto Bundle =
sycl::get_kernel_bundle<KernelName, sycl::bundle_state::executable>(Ctx);
sycl::get_kernel_bundle<KernelName, sycl::bundle_state::executable>(
Ctx, {Dev});
return Bundle.template get_kernel<KernelName>().template get_info<Param>(Dev);
}

Expand All @@ -49,7 +50,7 @@ typename sycl::detail::is_kernel_device_specific_info_desc<Param>::return_type
get_kernel_info(const queue &Q) {
auto Bundle =
sycl::get_kernel_bundle<KernelName, sycl::bundle_state::executable>(
Q.get_context());
Q.get_context(), {Q.get_device()});
return Bundle.template get_kernel<KernelName>().template get_info<Param>(
Q.get_device());
}
Expand All @@ -73,7 +74,7 @@ std::enable_if_t<ext::oneapi::experimental::is_kernel_v<Func>,
Param>::return_type>
get_kernel_info(const context &ctxt, const device &dev) {
auto Bundle = sycl::ext::oneapi::experimental::get_kernel_bundle<
Func, sycl::bundle_state::executable>(ctxt);
Func, sycl::bundle_state::executable>(ctxt, {dev});
return Bundle.template ext_oneapi_get_kernel<Func>().template get_info<Param>(
dev);
}
Expand Down
1 change: 1 addition & 0 deletions sycl/unittests/kernel-and-program/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ add_sycl_unittest(KernelAndProgramTests OBJECT
OutOfResources.cpp
InMemCacheEviction.cpp
KernelArgs.cpp
KernelInfoShortcuts.cpp
)
target_compile_definitions(KernelAndProgramTests_non_preview PRIVATE __SYCL_INTERNAL_API)
target_compile_definitions(KernelAndProgramTests_preview PRIVATE __SYCL_INTERNAL_API __INTEL_PREVIEW_BREAKING_CHANGES)
73 changes: 73 additions & 0 deletions sycl/unittests/kernel-and-program/KernelInfoShortcuts.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
//==-------------------------- KernelInfoShortcuts.cpp -------------------==//
//
// Unit test to ensure get_kernel_info for a device queries/uses kernel bundle
// for that specific device only and doesn't trigger builds for all devices.
//

#include <helpers/MockDeviceImage.hpp>
#include <helpers/MockKernelInfo.hpp>
#include <helpers/ScopedEnvVar.hpp>
#include <helpers/UrMock.hpp>
#include <sycl/sycl.hpp>

#include <gtest/gtest.h>

using namespace sycl;
using namespace sycl::unittest;

class ShortcutKernelInfoTestKernel;
MOCK_INTEGRATION_HEADER(ShortcutKernelInfoTestKernel)

static int ProgramBuildCounter = 0;
static ur_result_t redefinedurProgramBuild(void *pParams) {
++ProgramBuildCounter;
return UR_RESULT_SUCCESS;
}

static ur_result_t redefinedDeviceGet(void *pParams) {
auto params = *static_cast<ur_device_get_params_t *>(pParams);
if (*params.ppNumDevices) {
**params.ppNumDevices = 2; // two devices total
return UR_RESULT_SUCCESS;
}
if (*params.pphDevices) {
// provide two mock device handles
(*params.pphDevices)[0] = reinterpret_cast<ur_device_handle_t>(0x1);
(*params.pphDevices)[1] = reinterpret_cast<ur_device_handle_t>(0x2);
}
return UR_RESULT_SUCCESS;
}

ur_result_t redefinedurKernelGetGroupInfo(void *pParams) {
return UR_RESULT_SUCCESS;
}

TEST(ShortcutKernelInfo, QueryInfoForSingleDevice) {
unittest::UrMock<> Mock;
static sycl::unittest::MockDeviceImage DevImage =
sycl::unittest::generateDefaultImage({"ShortcutKernelInfoTestKernel"});
static sycl::unittest::MockDeviceImageArray<1> DevImageArray = {&DevImage};

mock::getCallbacks().set_replace_callback("urDeviceGet", &redefinedDeviceGet);
mock::getCallbacks().set_replace_callback("urProgramBuildExp",
&redefinedurProgramBuild);
mock::getCallbacks().set_replace_callback("urKernelGetGroupInfo",
&redefinedurKernelGetGroupInfo);

platform Plt = platform();
std::vector<device> Devs = Plt.get_devices();
ASSERT_GE(Devs.size(), 2u) << "Test requires at least 2 devices";
context Ctx = context(Devs);
queue Queue = queue(Ctx, Devs[0]);

// Query kernel info for the first device only
ProgramBuildCounter = 0;
sycl::ext::oneapi::get_kernel_info<
ShortcutKernelInfoTestKernel,
sycl::info::kernel_device_specific::work_group_size>(Ctx, Devs[0]);
sycl::ext::oneapi::get_kernel_info<
ShortcutKernelInfoTestKernel,
sycl::info::kernel_device_specific::work_group_size>(Queue);

EXPECT_EQ(ProgramBuildCounter, 1);
}
Loading