diff --git a/sycl/include/sycl/ext/intel/esimd/memory.hpp b/sycl/include/sycl/ext/intel/esimd/memory.hpp index 0cd77ee545135..2540ad6724ad3 100644 --- a/sycl/include/sycl/ext/intel/esimd/memory.hpp +++ b/sycl/include/sycl/ext/intel/esimd/memory.hpp @@ -623,6 +623,10 @@ __ESIMD_API std::enable_if_t< simd> gather(const T *p, OffsetSimdViewT byte_offsets, simd_mask mask, PassThruSimdViewT pass_thru, PropertyListT props = {}) { + static_assert(N / VS == + OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(), + "Size of pass_thru parameter must correspond to the size of " + "byte_offsets parameter."); return gather(p, byte_offsets.read(), mask, pass_thru.read(), props); } @@ -662,6 +666,10 @@ __ESIMD_API std::enable_if_t< simd> gather(const T *p, OffsetSimdViewT byte_offsets, simd_mask mask, simd pass_thru, PropertyListT props = {}) { + static_assert(N / VS == + OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(), + "Size of pass_thru parameter must correspond to the size of " + "byte_offsets parameter."); return gather(p, byte_offsets.read(), mask, pass_thru, props); } @@ -1012,6 +1020,10 @@ __ESIMD_API std::enable_if_t< ext::oneapi::experimental::is_property_list_v> scatter(T *p, OffsetSimdViewT byte_offsets, ValuesSimdViewT vals, simd_mask mask, PropertyListT props = {}) { + static_assert(N / VS == + OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(), + "Size of vals parameter must correspond to the size of " + "byte_offsets parameter."); scatter(p, byte_offsets.read(), vals.read(), mask, props); } @@ -1116,6 +1128,10 @@ __ESIMD_API std::enable_if_t< ext::oneapi::experimental::is_property_list_v> scatter(T *p, OffsetSimdViewT byte_offsets, simd vals, simd_mask mask, PropertyListT props = {}) { + static_assert(N / VS == + OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(), + "Size of vals parameter must correspond to the size of " + "byte_offsets parameter."); scatter(p, byte_offsets.read(), vals, mask, props); } @@ -1150,6 +1166,10 @@ __ESIMD_API std::enable_if_t< ext::oneapi::experimental::is_property_list_v> scatter(T *p, OffsetSimdViewT byte_offsets, simd vals, PropertyListT props = {}) { + static_assert(N / VS == + OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(), + "Size of vals parameter must correspond to the size of " + "byte_offsets parameter."); scatter(p, byte_offsets.read(), vals, props); } @@ -1221,8 +1241,11 @@ __ESIMD_API std::enable_if_t< ext::oneapi::experimental::is_property_list_v> scatter(T *p, OffsetSimdViewT byte_offsets, ValuesSimdViewT vals, PropertyListT props = {}) { - simd_mask Mask = 1; - scatter(p, byte_offsets.read(), vals.read(), Mask, props); + static_assert(N / VS == + OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(), + "Size of vals parameter must correspond to the size of " + "byte_offsets parameter."); + scatter(p, byte_offsets.read(), vals.read(), props); } /// A variation of \c scatter API with \c offsets represented as scalar.