From 2c63c708027d634c10844caf6f265876630e12fa Mon Sep 17 00:00:00 2001 From: "Fine, Gregory" Date: Thu, 30 May 2024 16:08:59 -0700 Subject: [PATCH 1/2] Add static_asserts to existing autodeduction API --- sycl/include/sycl/ext/intel/esimd/memory.hpp | 31 ++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/sycl/include/sycl/ext/intel/esimd/memory.hpp b/sycl/include/sycl/ext/intel/esimd/memory.hpp index 0cd77ee545135..2cad9c1b6b613 100644 --- a/sycl/include/sycl/ext/intel/esimd/memory.hpp +++ b/sycl/include/sycl/ext/intel/esimd/memory.hpp @@ -623,6 +623,10 @@ __ESIMD_API std::enable_if_t< simd> gather(const T *p, OffsetSimdViewT byte_offsets, simd_mask mask, PassThruSimdViewT pass_thru, PropertyListT props = {}) { + static_assert(N / VS == + OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(), + "Size of pass_thru parameter must correspond to the size of " + "byte_offsets parameter."); return gather(p, byte_offsets.read(), mask, pass_thru.read(), props); } @@ -662,6 +666,10 @@ __ESIMD_API std::enable_if_t< simd> gather(const T *p, OffsetSimdViewT byte_offsets, simd_mask mask, simd pass_thru, PropertyListT props = {}) { + static_assert(N / VS == + OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(), + "Size of pass_thru parameter must correspond to the size of " + "byte_offsets parameter."); return gather(p, byte_offsets.read(), mask, pass_thru, props); } @@ -731,6 +739,10 @@ __ESIMD_API std::enable_if_t< simd> gather(const T *p, OffsetSimdViewT byte_offsets, simd_mask mask, PropertyListT props = {}) { + static_assert(N / VS == + OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(), + "Size of pass_thru parameter must correspond to the size of " + "byte_offsets parameter."); return gather(p, byte_offsets.read(), mask, props); } @@ -1012,6 +1024,10 @@ __ESIMD_API std::enable_if_t< ext::oneapi::experimental::is_property_list_v> scatter(T *p, OffsetSimdViewT byte_offsets, ValuesSimdViewT vals, simd_mask mask, PropertyListT props = {}) { + static_assert(N / VS == + OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(), + "Size of pass_thru parameter must correspond to the size of " + "byte_offsets parameter."); scatter(p, byte_offsets.read(), vals.read(), mask, props); } @@ -1116,6 +1132,10 @@ __ESIMD_API std::enable_if_t< ext::oneapi::experimental::is_property_list_v> scatter(T *p, OffsetSimdViewT byte_offsets, simd vals, simd_mask mask, PropertyListT props = {}) { + static_assert(N / VS == + OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(), + "Size of pass_thru parameter must correspond to the size of " + "byte_offsets parameter."); scatter(p, byte_offsets.read(), vals, mask, props); } @@ -1150,6 +1170,10 @@ __ESIMD_API std::enable_if_t< ext::oneapi::experimental::is_property_list_v> scatter(T *p, OffsetSimdViewT byte_offsets, simd vals, PropertyListT props = {}) { + static_assert(N / VS == + OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(), + "Size of pass_thru parameter must correspond to the size of " + "byte_offsets parameter."); scatter(p, byte_offsets.read(), vals, props); } @@ -1221,8 +1245,11 @@ __ESIMD_API std::enable_if_t< ext::oneapi::experimental::is_property_list_v> scatter(T *p, OffsetSimdViewT byte_offsets, ValuesSimdViewT vals, PropertyListT props = {}) { - simd_mask Mask = 1; - scatter(p, byte_offsets.read(), vals.read(), Mask, props); + static_assert(N / VS == + OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(), + "Size of pass_thru parameter must correspond to the size of " + "byte_offsets parameter."); + scatter(p, byte_offsets.read(), vals.read(), props); } /// A variation of \c scatter API with \c offsets represented as scalar. From 630f473afe195184f5bcf8ea04d75ece12c48f5f Mon Sep 17 00:00:00 2001 From: "Fine, Gregory" Date: Fri, 31 May 2024 09:18:23 -0700 Subject: [PATCH 2/2] Updated asserts --- sycl/include/sycl/ext/intel/esimd/memory.hpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/sycl/include/sycl/ext/intel/esimd/memory.hpp b/sycl/include/sycl/ext/intel/esimd/memory.hpp index 2cad9c1b6b613..2540ad6724ad3 100644 --- a/sycl/include/sycl/ext/intel/esimd/memory.hpp +++ b/sycl/include/sycl/ext/intel/esimd/memory.hpp @@ -739,10 +739,6 @@ __ESIMD_API std::enable_if_t< simd> gather(const T *p, OffsetSimdViewT byte_offsets, simd_mask mask, PropertyListT props = {}) { - static_assert(N / VS == - OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(), - "Size of pass_thru parameter must correspond to the size of " - "byte_offsets parameter."); return gather(p, byte_offsets.read(), mask, props); } @@ -1026,7 +1022,7 @@ scatter(T *p, OffsetSimdViewT byte_offsets, ValuesSimdViewT vals, simd_mask mask, PropertyListT props = {}) { static_assert(N / VS == OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(), - "Size of pass_thru parameter must correspond to the size of " + "Size of vals parameter must correspond to the size of " "byte_offsets parameter."); scatter(p, byte_offsets.read(), vals.read(), mask, props); } @@ -1134,7 +1130,7 @@ scatter(T *p, OffsetSimdViewT byte_offsets, simd vals, simd_mask mask, PropertyListT props = {}) { static_assert(N / VS == OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(), - "Size of pass_thru parameter must correspond to the size of " + "Size of vals parameter must correspond to the size of " "byte_offsets parameter."); scatter(p, byte_offsets.read(), vals, mask, props); } @@ -1172,7 +1168,7 @@ scatter(T *p, OffsetSimdViewT byte_offsets, simd vals, PropertyListT props = {}) { static_assert(N / VS == OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(), - "Size of pass_thru parameter must correspond to the size of " + "Size of vals parameter must correspond to the size of " "byte_offsets parameter."); scatter(p, byte_offsets.read(), vals, props); } @@ -1247,7 +1243,7 @@ scatter(T *p, OffsetSimdViewT byte_offsets, ValuesSimdViewT vals, PropertyListT props = {}) { static_assert(N / VS == OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(), - "Size of pass_thru parameter must correspond to the size of " + "Size of vals parameter must correspond to the size of " "byte_offsets parameter."); scatter(p, byte_offsets.read(), vals.read(), props); }