 #include "binbcast.hpp"
 
+#include <array>
 #include <cstddef>
 #include <cstdint>
 #include <sycl/sycl.hpp>
 
+#include "dpct/helper.hpp"
 #include "ggml.h"
 
-template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
-static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
-        int ne0, int ne1, int ne2, int ne3,
-        int ne10, int ne11, int ne12, int ne13,
-        /*int s0, */ int s1, int s2, int s3,
-        /*int s00,*/ int s01, int s02, int s03,
-        /*int s10,*/ int s11, int s12, int s13,
-        const sycl::nd_item<3> &item_ct1) {
-    const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1));
-    const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
-                    item_ct1.get_local_id(0)) /
-                   ne3;
-    const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
-                    item_ct1.get_local_id(0)) %
-                   ne3;
-
-    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
-        return;
-    }
-
-    const int i11 = i1 % ne11;
-    const int i12 = i2 % ne12;
-    const int i13 = i3 % ne13;
-
-    const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;
-    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
-    const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;
-
-    const src0_t * src0_row = src0 + i_src0;
-    const src1_t * src1_row = src1 + i_src1;
-    dst_t * dst_row = dst + i_dst;
-
-    for (int i0 = i0s; i0 < ne0;
-         i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
-        const int i10 = i0 % ne10;
-        dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
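+// Fast path: when no broadcasting is needed and all three tensors are
+// contiguous, src0, src1 and dst share one flat layout, so a plain
+// grid-stride loop over num_elements is enough; work-item `i` handles
+// elements i, i + global_range, i + 2 * global_range, ...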
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static __dpct_inline__ void k_bin_bcast_contiguous(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1,
+                                                   dst_t * dst, std::size_t num_elements, const sycl::nd_item<1> & it) {
+    auto element_id = it.get_global_id(0);
+    auto global_range = it.get_global_range(0);
+    for (; element_id < num_elements; element_id += global_range) {
+        // round-trip through a 1-element sycl::vec to get explicit round-to-nearest-even conversions
+        auto src0_float_val = sycl::vec(src0[element_id]).template convert<float, sycl::rounding_mode::rte>();
+        auto src1_float_val = sycl::vec(src1[element_id]).template convert<float, sycl::rounding_mode::rte>();
+        float dst_val = bin_op(src0_float_val[0], src1_float_val[0]);
+        auto val_to_store = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
+        dst[element_id] = val_to_store;
     }
 }
 
-template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
-static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
-        int ne0, int ne1, int ne2, int ne3,
-        int ne10, int ne11, int ne12, int ne13,
-        /*int s0, */ int s1, int s2, int s3,
-        /*int s00,*/ int s01, int s02, int s03,
-        /*int s10,*/ int s11, int s12, int s13,
-        const sycl::nd_item<3> &item_ct1) {
-
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    const int i3 = i/(ne2*ne1*ne0);
-    const int i2 = (i/(ne1*ne0)) % ne2;
-    const int i1 = (i/ne0) % ne1;
-    const int i0 = i % ne0;
-
-    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
-        return;
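+// General path: walk dst linearly; each flat element id is unravelled into a
+// 4-d logical index (outermost dim first) and then mapped to per-tensor
+// offsets, wrapping any dim where a source is smaller than dst (broadcast).
+// Worked example: with dims { ne3, ne2, ne1, ne0 } = { 2, 3, 4, 5 },
+// element_id 37 unravels to { 0, 1, 3, 2 }, i.e. i3 = 0, i2 = 1, i1 = 3,
+// i0 = 2, since 37 = ((0 * 3 + 1) * 4 + 3) * 5 + 2.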
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static __dpct_inline__ void k_bin_bcast(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1, dst_t * dst,
+                                        int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13,
+                                        int s0, int s1, int s2, int s3, int s00, int s01, int s02, int s03, int s10,
+                                        int s11, int s12, int s13, std::size_t num_dst_elements,
+                                        const sycl::nd_item<1> & item_ct1) {
+    // unravel a flat element id into a 4-d logical index (innermost dim last)
+    auto calculate_logical_index =
+        [](const std::array<int, 4> & dims, std::size_t element_id) __attribute__((always_inline)) -> std::array<int, 4> {
+        std::array<int, 4> logical_index;
+#pragma unroll(4)
+        for (int i = 3; i >= 0; i--) {
+            logical_index[i] = element_id % dims[i];
+            element_id /= dims[i];
+        }
+        return logical_index;
+    };
+
+    // map a logical index to a flat offset, wrapping dims that are broadcast
+    auto calculate_index = [](const std::array<int, 4> & dims, const std::array<int, 4> & strides,
+                              const std::array<int, 4> & indices) __attribute__((always_inline)) -> std::size_t {
+        std::size_t index = 0;
+#pragma unroll(4)
+        for (int i = 0; i < 4; i++) {
+            auto index_i = indices[i];
+            if (indices[i] >= dims[i]) {
+                index_i = indices[i] % dims[i];
+            }
+            index += strides[i] * index_i;
+        }
+        return index;
+    };
+
+    auto element_id = item_ct1.get_global_id(0);
+    for (; element_id < num_dst_elements; element_id += item_ct1.get_global_range(0)) {
+        auto logical_index = calculate_logical_index({ ne3, ne2, ne1, ne0 }, element_id);
+        auto src_0_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s03, s02, s01, s00 }, logical_index);
+        auto src_1_index = calculate_index({ ne13, ne12, ne11, ne10 }, { s13, s12, s11, s10 }, logical_index);
+        auto dst_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s3, s2, s1, s0 }, logical_index);
+        auto src0_float_val = sycl::vec(src0[src_0_index]).template convert<float, sycl::rounding_mode::rte>();
+        auto src1_float_val = sycl::vec(src1[src_1_index]).template convert<float, sycl::rounding_mode::rte>();
+        float dst_val = bin_op(src0_float_val[0], src1_float_val[0]);
+        auto val_to_store = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
+        dst[dst_index] = val_to_store;
     }
-
-    const int i11 = i1 % ne11;
-    const int i12 = i2 % ne12;
-    const int i13 = i3 % ne13;
-
-    const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;
-    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
-    const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;
-
-    const src0_t * src0_row = src0 + i_src0;
-    const src1_t * src1_row = src1 + i_src1;
-    dst_t * dst_row = dst + i_dst;
-
-    const int i10 = i0 % ne10;
-    dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
 }
 
-
-template <float (*bin_op)(const float, const float)>
-struct bin_bcast_sycl {
+template <float (*bin_op)(const float, const float)> struct bin_bcast_sycl {
     template <typename src0_t, typename src1_t, typename dst_t>
     void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00,
                     const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11,
@@ -96,165 +77,73 @@ struct bin_bcast_sycl {
                     const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
                     const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous,
                     const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
-        int nr0 = ne10/ne0;
-        int nr1 = ne11/ne1;
-        int nr2 = ne12/ne2;
-        int nr3 = ne13/ne3;
-
-        int nr[4] = { nr0, nr1, nr2, nr3 };
-
-        // collapse dimensions until first broadcast dimension
-        int64_t cne[] = {ne0, ne1, ne2, ne3};
-        int64_t cne0[] = {ne00, ne01, ne02, ne03};
-        int64_t cne1[] = {ne10, ne11, ne12, ne13};
-        size_t cnb[] = {nb0, nb1, nb2, nb3};
-        size_t cnb0[] = {nb00, nb01, nb02, nb03};
-        size_t cnb1[] = {nb10, nb11, nb12, nb13};
-        auto collapse = [](int64_t cne[]) {
-            cne[0] *= cne[1];
-            cne[1] = cne[2];
-            cne[2] = cne[3];
-            cne[3] = 1;
-        };
-
-        auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
-            cnb[1] *= cne[1];
-            cnb[2] *= cne[2];
-            cnb[3] *= cne[3];
-        };
-
-        if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) {
+        auto check_bcast_required = [](const std::array<int64_t, 4> & src_dims,
+                                       const std::array<int64_t, 4> & dst_dims) -> bool {
             for (int i = 0; i < 4; i++) {
-                if (nr[i] != 1) {
-                    break;
-                }
-                if (i > 0) {
-                    collapse_nb(cnb, cne);
-                    collapse_nb(cnb0, cne0);
-                    collapse_nb(cnb1, cne1);
-                    collapse(cne);
-                    collapse(cne0);
-                    collapse(cne1);
+                if (dst_dims[i] > src_dims[i]) {
+                    return true;
                 }
             }
-        }
-        {
-            int64_t ne0 = cne[0];
-            int64_t ne1 = cne[1];
-            int64_t ne2 = cne[2];
-            int64_t ne3 = cne[3];
-
-            int64_t ne10 = cne1[0];
-            int64_t ne11 = cne1[1];
-            int64_t ne12 = cne1[2];
-            int64_t ne13 = cne1[3];
-
-            size_t nb0 = cnb[0];
-            size_t nb1 = cnb[1];
-            size_t nb2 = cnb[2];
-            size_t nb3 = cnb[3];
-
-            size_t nb00 = cnb0[0];
-            size_t nb01 = cnb0[1];
-            size_t nb02 = cnb0[2];
-            size_t nb03 = cnb0[3];
-
-            size_t nb10 = cnb1[0];
-            size_t nb11 = cnb1[1];
-            size_t nb12 = cnb1[2];
-            size_t nb13 = cnb1[3];
-
-            size_t s0 = nb0 / sizeof(dst_t);
-            size_t s1 = nb1 / sizeof(dst_t);
-            size_t s2 = nb2 / sizeof(dst_t);
-            size_t s3 = nb3 / sizeof(dst_t);
-
-            size_t s10 = nb10 / sizeof(src1_t);
-            size_t s11 = nb11 / sizeof(src1_t);
-            size_t s12 = nb12 / sizeof(src1_t);
-            size_t s13 = nb13 / sizeof(src1_t);
-
-            size_t s00 = nb00 / sizeof(src0_t);
-            size_t s01 = nb01 / sizeof(src0_t);
-            size_t s02 = nb02 / sizeof(src0_t);
-            size_t s03 = nb03 / sizeof(src0_t);
-
-            GGML_UNUSED(s00);
-
-            GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
-            GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
-            GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
-            GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
-
-            GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
-            GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
-            GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
-            GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
-
-            GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
-            GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
-            GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
-            GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
-
-            GGML_ASSERT(s0 == 1);
-            GGML_ASSERT(s10 == 1);
-
-            const int block_size = 128;
-
-            int64_t hne0 = std::max(ne0/2LL, 1LL);
-
-            sycl::range<3> block_dims(1, 1, 1);
-            block_dims[2] = std::min<unsigned int>(hne0, block_size);
-            block_dims[1] = std::min<unsigned int>(
-                ne1, block_size / (unsigned int)block_dims[2]);
-            block_dims[0] = std::min(
-                std::min<unsigned int>(
-                    ne2 * ne3, block_size / (unsigned int)block_dims[2] /
-                                   (unsigned int)block_dims[1]),
-                64U);
-
-            sycl::range<3> block_nums(
-                (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
-                (ne1 + block_dims[1] - 1) / block_dims[1],
-                (hne0 + block_dims[2] - 1) / block_dims[2]);
-
-            if (block_nums[0] > 65535) {
-                // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
-                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
-                {
-                    dpct::has_capability_or_fail(stream->get_device(),
-                                                 {sycl::aspect::fp16});
-
-                    stream->parallel_for(
-                        sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
-                                              sycl::range<3>(1, 1, block_size),
-                                          sycl::range<3>(1, 1, block_size)),
-                        [=](sycl::nd_item<3> item_ct1) {
-                            k_bin_bcast_unravel<bin_op>(
-                                src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
-                                ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
-                                s03, s11, s12, s13, item_ct1);
-                        });
-                }
-            } else {
-                /*
-                DPCT1049:16: The work-group size passed to the SYCL kernel may
-                exceed the limit. To get the device limit, query
-                info::device::max_work_group_size. Adjust the work-group size if
-                needed.
-                */
-                dpct::has_capability_or_fail(stream->get_device(),
-                                             {sycl::aspect::fp16});
-
-                stream->parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
-                                            ne2, ne3, ne10, ne11, ne12, ne13,
-                                            s1, s2, s3, s01, s02, s03, s11, s12, s13,
-                                            item_ct1);
-                    });
-            }
+            return false;
+        };
+
+        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
+
+        GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
+        GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
+        GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
+        GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
+
+        GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
+        GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
+        GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
+        GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
+
+        GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
+        GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
+        GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
+        GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
+
+        // dst strides in number of elements
+        size_t s0 = nb0 / sizeof(dst_t);
+        size_t s1 = nb1 / sizeof(dst_t);
+        size_t s2 = nb2 / sizeof(dst_t);
+        size_t s3 = nb3 / sizeof(dst_t);
+
+        // src1 strides in number of elements
+        size_t s10 = nb10 / sizeof(src1_t);
+        size_t s11 = nb11 / sizeof(src1_t);
+        size_t s12 = nb12 / sizeof(src1_t);
+        size_t s13 = nb13 / sizeof(src1_t);
+
+        // src0 strides in number of elements
+        size_t s00 = nb00 / sizeof(src0_t);
+        size_t s01 = nb01 / sizeof(src0_t);
+        size_t s02 = nb02 / sizeof(src0_t);
+        size_t s03 = nb03 / sizeof(src0_t);
+
+        std::size_t num_dst_elements = static_cast<std::size_t>(ne0) * static_cast<std::size_t>(ne1) *
+                                       static_cast<std::size_t>(ne2) * static_cast<std::size_t>(ne3);
+        std::size_t local_range = 256;
+        std::size_t global_range = ceil_div(num_dst_elements, local_range) * local_range;
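+        // e.g. num_dst_elements = 1000, local_range = 256 ->
+        // global_range = ceil_div(1000, 256) * 256 = 4 * 256 = 1024
+        // (the nd_range needs a multiple of the work-group size)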
+
+        bool needs_broadcasting = check_bcast_required({ ne00, ne01, ne02, ne03 }, { ne0, ne1, ne2, ne3 }) ||
+                                  check_bcast_required({ ne10, ne11, ne12, ne13 }, { ne0, ne1, ne2, ne3 });
+        bool all_contiguous = src0_is_contiguous && src1_is_contiguous && dst_is_contiguous;
+
+        if (!needs_broadcasting && all_contiguous) {
+            stream->submit([&](sycl::handler & cgh) {
+                cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) {
+                    k_bin_bcast_contiguous<bin_op>(src0_dd, src1_dd, dst_dd, num_dst_elements, it);
+                });
+            });
+        } else {
+            stream->submit([&](sycl::handler & cgh) {
+                cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) {
+                    k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, s0, s1,
+                                        s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, num_dst_elements, it);
+                });
+            });
         }
     }
 };