|
3 | 3 | static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, |
4 | 4 | const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) { |
5 | 5 |
|
6 | | - const int nrows = item_ct1.get_group_range(2); |
| 6 | + const int nrows = item_ct1.get_group_range(0); |
7 | 7 | const int nchannels = item_ct1.get_group_range(1); |
8 | 8 | const int nthreads = item_ct1.get_local_range(2); |
9 | | - const int sample = item_ct1.get_group(0); |
| 9 | + const int sample = item_ct1.get_group(2); |
10 | 10 | const int channel = item_ct1.get_group(1); |
11 | | - const int row = item_ct1.get_group(2); |
| 11 | + const int row = item_ct1.get_group(0); |
12 | 12 |
|
13 | 13 | const int tid = item_ct1.get_local_id(2); |
14 | 14 | const int nwarps = nthreads / WARP_SIZE; |
@@ -140,11 +140,11 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con |
140 | 140 | static void rms_norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, |
141 | 141 | const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) { |
142 | 142 |
|
143 | | - const int nrows = item_ct1.get_group_range(2); |
| 143 | + const int nrows = item_ct1.get_group_range(0); |
144 | 144 | const int nchannels = item_ct1.get_group_range(1); |
145 | | - const int sample = item_ct1.get_group(0); |
| 145 | + const int sample = item_ct1.get_group(2); |
146 | 146 | const int channel = item_ct1.get_group(1); |
147 | | - const int row = item_ct1.get_group(2); |
| 147 | + const int row = item_ct1.get_group(0); |
148 | 148 | const int nthreads = item_ct1.get_local_range(2); |
149 | 149 |
|
150 | 150 | const int tid = item_ct1.get_local_id(2); |
@@ -237,10 +237,10 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i |
237 | 237 | const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, |
238 | 238 | const float eps, queue_ptr stream, int device) { |
239 | 239 |
|
240 | | - const sycl::range<3> global_dims(nsamples, nchannels, nrows); |
| 240 | + const sycl::range<3> global_dims(nrows, nchannels, nsamples); |
241 | 241 | GGML_ASSERT(ncols % WARP_SIZE == 0); |
242 | 242 | if (ncols < 1024) { |
243 | | - const sycl::range<3> block_dims(1, 1, WARP_SIZE); // Equivalent to CUDA's (WARP_SIZE, 1, 1) |
| 243 | + const sycl::range<3> block_dims(1, 1, WARP_SIZE); |
244 | 244 | stream->submit([&](sycl::handler& cgh) { |
245 | 245 | cgh.parallel_for( |
246 | 246 | sycl::nd_range<3>(global_dims * block_dims, block_dims), |
@@ -324,7 +324,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const |
324 | 324 | GGML_ASSERT(ncols % WARP_SIZE == 0); |
325 | 325 | // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE); |
326 | 326 |
|
327 | | - const sycl::range<3> global_dims(nsamples, nchannels, nrows); |
| 327 | + const sycl::range<3> global_dims(nrows, nchannels, nsamples); |
328 | 328 | if (ncols < 1024) { |
329 | 329 | const sycl::range<3> block_dims(1, 1, WARP_SIZE); |
330 | 330 | stream->submit([&](sycl::handler& cgh) { |
|
0 commit comments