Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions sycl/include/sycl/ext/oneapi/bindless_images.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1036,15 +1036,19 @@ DataT fetch_image(const sampled_image_handle &imageHandle [[maybe_unused]],
"HintT must always be a recognized standard type");

#ifdef __SYCL_DEVICE_ONLY__
// Convert the raw handle to an image and use FETCH_UNSAMPLED_IMAGE since
// fetch_image should not use the sampler
if constexpr (detail::is_recognized_standard_type<DataT>()) {
return FETCH_SAMPLED_IMAGE(
return FETCH_UNSAMPLED_IMAGE(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Side note:

  1. Some of these macros seem a little unnecessary, i.e. some of the just wrap the same function in all paths, which is in turn no longer than the macro itself.
  2. The macros bleed into user-space, i.e. we don't undef them at the end of the header. We should probably address that. Maybe we should move away from macros and use inline or anonymous functions instead?

Would you mind opening a follow-up fix for this and/or open an issue?

DataT,
CONVERT_HANDLE_TO_SAMPLED_IMAGE(imageHandle.raw_handle, coordSize),
CONVERT_HANDLE_TO_IMAGE(imageHandle.raw_handle,
detail::OCLImageTyRead<coordSize>),
coords);
} else {
return sycl::bit_cast<DataT>(FETCH_SAMPLED_IMAGE(
return sycl::bit_cast<DataT>(FETCH_UNSAMPLED_IMAGE(
HintT,
CONVERT_HANDLE_TO_SAMPLED_IMAGE(imageHandle.raw_handle, coordSize),
CONVERT_HANDLE_TO_IMAGE(imageHandle.raw_handle,
detail::OCLImageTyRead<coordSize>),
coords));
}
#else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
#include <sycl/ext/oneapi/bindless_images.hpp>
#include <sycl/usm.hpp>

class kernel_sampled_fetch;
namespace {

int main() {
template <typename T, sycl::image_channel_type ChanType>
static int testSampledImageFetch() {

sycl::device dev;
sycl::queue q(dev);
Expand All @@ -23,9 +24,9 @@ int main() {
constexpr size_t width = 5;
constexpr size_t height = 6;
constexpr size_t N = width * height;
std::vector<sycl::vec<uint16_t, 4>> out(N);
std::vector<sycl::vec<uint16_t, 4>> expected(N);
std::vector<sycl::vec<uint16_t, 4>> dataIn(N);
std::vector<sycl::vec<T, 4>> out(N);
std::vector<sycl::vec<T, 4>> expected(N);
std::vector<sycl::vec<T, 4>> dataIn(N);
for (int i = 0; i < width; i++) {
for (int j = 0; j < height; j++) {
auto index = i + (width * j);
Expand All @@ -43,8 +44,7 @@ int main() {
sycl::filtering_mode::linear);

// Extension: image descriptor
syclexp::image_descriptor desc({width, height}, 4,
sycl::image_channel_type::unsigned_int16);
syclexp::image_descriptor desc({width, height}, 4, ChanType);
size_t pitch = 0;

// Extension: returns the device pointer to USM allocated pitched memory
Expand All @@ -65,21 +65,20 @@ int main() {

sycl::buffer buf(out.data(), sycl::range{height, width});
q.submit([&](sycl::handler &cgh) {
auto outAcc = buf.get_access<sycl::access_mode::write>(
auto outAcc = buf.template get_access<sycl::access_mode::write>(
cgh, sycl::range<2>{height, width});

cgh.parallel_for<kernel_sampled_fetch>(
sycl::nd_range<2>{{width, height}, {width, height}},
[=](sycl::nd_item<2> it) {
size_t dim0 = it.get_local_id(0);
size_t dim1 = it.get_local_id(1);
cgh.parallel_for(sycl::nd_range<2>{{width, height}, {width, height}},
[=](sycl::nd_item<2> it) {
size_t dim0 = it.get_local_id(0);
size_t dim1 = it.get_local_id(1);

// Extension: fetch data from sampled image handle
auto px1 = syclexp::fetch_image<sycl::vec<uint16_t, 4>>(
imgHandle, sycl::int2(dim0, dim1));
// Extension: fetch data from sampled image handle
auto px1 = syclexp::fetch_image<sycl::vec<T, 4>>(
imgHandle, sycl::int2(dim0, dim1));

outAcc[sycl::id<2>{dim1, dim0}] = px1;
});
outAcc[sycl::id<2>{dim1, dim0}] = px1;
});
});

q.wait_and_throw();
Expand Down Expand Up @@ -121,3 +120,23 @@ int main() {
std::cout << "Test failed!" << std::endl;
return 3;
}

} // namespace

int main() {
if (int err =
testSampledImageFetch<uint16_t,
sycl::image_channel_type::unsigned_int16>()) {
return err;
}
if (int err =
testSampledImageFetch<uint32_t,
sycl::image_channel_type::unsigned_int32>()) {
return err;
}
if (int err =
testSampledImageFetch<float, sycl::image_channel_type::fp32>()) {
return err;
}
return 0;
}