-
Notifications
You must be signed in to change notification settings - Fork 0
/
kernel_resource_strings.h
664 lines (587 loc) · 23.2 KB
/
kernel_resource_strings.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
// IO data structure for kernel code.
// Emitted at the top of every generated kernel. The fixed-width integer
// typedefs are spelled out by hand because the generated source is compiled
// standalone (no standard headers are available to the runtime compiler).
// Tensor<T, N> carries a raw data pointer plus per-dimension size/stride
// arrays; indexing is by a precomputed linear (flat) index.
static auto code_template_tensor_struct = R"(
typedef unsigned char uint8_t;
typedef signed char int8_t;
typedef short int int16_t;
typedef long long int int64_t;
template<typename T, int N>
struct Tensor {
__device__ T& operator[](int64_t ind) {
return data[ind];
};
T* data;
int64_t size[N];
int64_t stride[N];
};
// Specialization for 0-dim case as it does not need size and stride arrays.
// They will be an error as well since zero-length arrays are not allowed.
template<typename T>
struct Tensor<T, 0> {
__device__ T& operator[](int64_t) {
return *data;
};
T* data;
};
)";
// Code support for FP16 __half type and intrinsics.
// On ROCm (__HIP_PLATFORM_HCC__) the string is empty — presumably the HIP
// toolchain provides __half and the conversions natively; NOTE(review):
// confirm against the HIP headers.
#ifdef __HIP_PLATFORM_HCC__
static auto code_fp16_support = R"()";
#else
// Minimal storage-only __half definition plus float<->half conversions
// implemented with inline PTX (cvt.rn.f16.f32 for round-to-nearest-even
// narrowing, cvt.f32.f16 for widening). The macros reinterpret the __half
// storage as an unsigned short for the "h" asm constraint.
static auto code_fp16_support = R"(
#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
#define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))
struct __align__(2) __half {
__host__ __device__ __half() { }
protected:
unsigned short __x;
};
/* Definitions of intrinsics */
__device__ __half __float2half(const float f) {
__half val;
asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(__HALF_TO_US(val)) : "f"(f));
return val;
}
__device__ float __half2float(const __half h) {
float val;
asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(__HALF_TO_CUS(h)));
return val;
}
)";
#endif
// struct and code for functions that need random number generation.
// Implements the Philox counter-based RNG (the loop runs 9 rounds plus one
// final single_round, i.e. 10 rounds total), seeded per (seed, subsequence,
// offset). uniform() maps a 32-bit draw onto a float via multiplication by
// 2^-32. Emitted only for kernels that actually sample random numbers.
static auto code_random_number_gen = R"(
class Philox {
public:
__device__ inline Philox(unsigned long long seed,
unsigned long long subsequence,
unsigned long long offset) {
key.x = (unsigned int)seed;
key.y = (unsigned int)(seed >> 32);
counter = make_uint4(0, 0, 0, 0);
counter.z = (unsigned int)(subsequence);
counter.w = (unsigned int)(subsequence >> 32);
STATE = 0;
incr_n(offset / 4);
}
__device__ inline unsigned long operator()() {
if(STATE == 0) {
uint4 counter_ = counter;
uint2 key_ = key;
for(int i = 0; i < 9; i++) {
counter_ = single_round(counter_, key_);
key_.x += (kPhilox10A); key_.y += (kPhilox10B);
}
output = single_round(counter_, key_);
incr();
}
unsigned long ret;
switch(STATE) {
case 0: ret = output.x; break;
case 1: ret = output.y; break;
case 2: ret = output.z; break;
case 3: ret = output.w; break;
}
STATE = (STATE + 1) % 4;
return ret;
}
private:
uint4 counter;
uint4 output;
uint2 key;
unsigned int STATE;
__device__ inline void incr_n(unsigned long long n) {
unsigned int nlo = (unsigned int)(n);
unsigned int nhi = (unsigned int)(n >> 32);
counter.x += nlo;
if (counter.x < nlo)
nhi++;
counter.y += nhi;
if (nhi <= counter.y)
return;
if (++counter.z)
return;
++counter.w;
}
__device__ inline void incr() {
if (++counter.x)
return;
if (++counter.y)
return;
if (++counter.z)
return;
++counter.w;
}
__device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
unsigned int *result_high) {
*result_high = __umulhi(a, b);
return a*b;
}
__device__ inline uint4 single_round(uint4 ctr, uint2 key) {
unsigned int hi0;
unsigned int hi1;
unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
return ret;
}
static const unsigned long kPhilox10A = 0x9E3779B9;
static const unsigned long kPhilox10B = 0xBB67AE85;
static const unsigned long kPhiloxSA = 0xD2511F53;
static const unsigned long kPhiloxSB = 0xCD9E8D57;
};
// Inverse of 2^32.
#define M_RAN_INVM32 2.3283064e-10f
__device__ __inline__ float uniform(unsigned int x) {
return x * M_RAN_INVM32;
}
)";
// Helper functions for Operations.
// Scalar device helpers used by generated pointwise ops; all work on float.
// ceilDiv/alignBufferSize are constexpr so they can size buffers at compile
// time (alignBufferSize rounds up to a multiple of `size`, which must be a
// power of two for the mask trick to be valid). randLike requires the Philox
// class from code_random_number_gen to have been emitted first.
static auto code_helper_funcs = R"(
__device__ constexpr int ceilDiv(const int a, const int b) {
return (a + b - 1) / b;
}
__device__ constexpr int alignBufferSize(const int buffer, const int size) {
return (buffer + (size-1)) & ~(size-1);
}
__device__ float clamp(const float x, const float minv, const float maxv) {
return x < minv ? minv : (x > maxv ? maxv : x);
}
__device__ float frac(const float x) {
return x - truncf(x);
}
__device__ float gelu(const float x) {
return x * normcdf(x);
}
__device__ float reciprocal(const float x) {
return 1.f / x;
}
__device__ float relu(const float x) {
return x <= 0.f ? 0.f : x;
}
__device__ float remainder(const float a, const float b) {
return a - b * floorf(a / b);
}
__device__ float sigmoid(const float x) {
return 1.f / (1.f + expf(-x));
}
__device__ float threshold(const float x, const float t, const float v) {
return x <= t ? v : x;
}
__device__ float where(const bool c, const float a, const float b) {
return c ? a : b;
}
__device__ float randLike(Philox rnd) {
return uniform(rnd());
};
)";
// Note: We aggressively template functions taking dim3 in the functions below
// because ROCm uses different types for the various dim3 and maps them
// directly to intrinsics, but they're dim3 when used after modification.
/*
 * EXAMPLE USAGE:
 * blockReduceSum<X_THREADS, Y_THREADS, Z_THREADS>
 *   (output[output_index], inputs[input_index], [] __device__ (T& a, const T
 *   b) { a += b; } );
 */
// Intra-block tree reduction over shared memory. Each participating thread
// deposits its value (or init_val when read_write_pred is false), the block
// first folds the tail beyond the largest power of two, then halves the
// active range each step with a __syncthreads() barrier between steps.
// Only the thread at offset 0 of every reduced dimension writes `out`.
static auto code_template_block_reduction = R"(
// [Z,Y,X]_THREADS is the number of participating threads in the z, y, x
// dimension of the block. If set to 0 it means that dimension doesn't
// participate, otherwise it is the number of threads. We could start with warp
// reductions, then reduce the warps, this could save some shared memory, but
// may actually be slower.
template<bool X_REDUCE, bool Y_REDUCE, bool Z_REDUCE, typename T, typename Func, typename _dim3ti, typename _dim3bd>
__inline__ __device__
void blockReduce(
T& out,
const T inp_val,
Func reduction_op,
const _dim3ti& thread_idx,
const _dim3bd& block_dim,
T* shared_mem,
bool read_write_pred,
T init_val) {
unsigned int reduction_size
= (X_REDUCE ? block_dim.x : 1)
* (Y_REDUCE ? block_dim.y : 1)
* (Z_REDUCE ? block_dim.z : 1);
// If this thread will output a final result
bool should_write = true;
if (X_REDUCE)
should_write = should_write && thread_idx.x == 0;
if (Y_REDUCE)
should_write = should_write && thread_idx.y == 0;
if (Z_REDUCE)
should_write = should_write && thread_idx.z == 0;
unsigned int reduction_stride;
unsigned int reduction_tid;
unsigned int linear_tid;
if(X_REDUCE && !Y_REDUCE && Z_REDUCE){
// Transpose Z and Y in the shared memory so Z and X dims are contiguous in smem
reduction_stride = 1;
linear_tid = threadIdx.y * blockDim.z * blockDim.x + threadIdx.z * blockDim.x + threadIdx.x;
reduction_tid = threadIdx.z * blockDim.x + threadIdx.x;
} else {
// Normal reduction in order
reduction_stride
= (X_REDUCE ? 1
: (Y_REDUCE ? block_dim.x
: (Z_REDUCE ? block_dim.x * block_dim.y : 0)));
linear_tid = thread_idx.z * block_dim.y * block_dim.x + thread_idx.y * block_dim.x + thread_idx.x;
reduction_tid
= ( Z_REDUCE ? thread_idx.z : 0 ) * ( Y_REDUCE ? block_dim.y : 1 ) * ( X_REDUCE ? block_dim.x : 1 )
+ ( Y_REDUCE ? thread_idx.y : 0 ) * ( X_REDUCE ? block_dim.x : 1 )
+ ( X_REDUCE ? thread_idx.x : 0 );
}
assert( reduction_stride != 0 );
if(read_write_pred){
shared_mem[linear_tid] = inp_val;
} else {
shared_mem[linear_tid] = init_val;
}
__syncthreads();
// Reduce down to nearest power of 2:
int np2 = 1 << (31 - __clz(reduction_size));
if( reduction_tid < np2 ){
if( reduction_tid + np2 < reduction_size){
reduction_op( shared_mem[linear_tid], shared_mem[linear_tid + np2 * reduction_stride] );
}
}
__syncthreads();
//for (int factor = np2/2; factor > contig_threads / 2; factor>>=1) {
for (int factor = np2/2; factor > 0; factor>>=1) {
if (reduction_tid < factor) {
reduction_op( shared_mem[linear_tid], shared_mem[linear_tid + factor * reduction_stride] );
}
__syncthreads();
}
if(should_write && read_write_pred)
out = shared_mem[linear_tid];
}
)";
/**
Inter-block reduction.
Function gridReduce performs point-wise reductions of scalars across thread
blocks. Thread blocks are disjointly partitioned into groups of thread blocks,
"reduction segments," that are collectively defined by boolean template
parameters, X_BLOCK, Y_BLOCK and Z_BLOCK. Each of X/Y/Z_BLOCK determines
whether thread blocks along the dimension should be grouped into the same
reduction segment. Cross-block reductions are independently done within each
segment and generate distinct results per segment. For instance, if all of
X/Y/Z_BLOCK are true, reductions will be done across all thread blocks since
there will be just a single segment consisting of all thread blocks. If none
of them are true, each thread block will become a segment by itself, so no
reduction will be performed.
The input scalars to reduce within each segment are a certain subset of
thread-private scalars provided as part of the gridReduce function parameters.
Boolean template parameters, X_THREAD, Y_THREAD and Z_THREAD, determine which
subset of the scalars should be used for inter-block reductions. Specifically,
all the input scalars of threads along each dimension will be used when
X/Y/Z_THREAD are true. Otherwise, only the value held at offset 0 of each
dimension will be used. Thus, for example, if all of X/Y/Z_THREAD are true,
the scalars of all threads in each block will participate in inter-block
reductions. If all of them are false, only one scalar of the thread at
threadIdx.x == threadIdx.y == threadIdx.z == 0 will be used. In the code
below, we call the subset of threads a "reduction block."
Inter-block reductions perform point-wise reductions of scalars of reduction
blocks within each reduction segment. More specifically, let rb be a reduction
block and rs be a reduction segment. Let IN(thread_idx, block_idx) denote the
input scalar of thread at thread_idx and block_idx. The result of each
reduction segment, OUT(thread_idx, block_idx_out), is defined only for each
thread_idx in thread block block_idx_out in the segment as follows:
OUT(thread_idx, block_idx_out) = Reduction of IN(thread_idx, block_idx) for
all block_idx in a reduction segment
OUT is not given for all threads that are not in block_idx_out and the
reduction block.
See also the function comment of gridReduce.
Mechanism: every reduction block stages its values into a global work buffer,
then each block atomically increments a per-segment sync flag; the last block
to arrive (after a __threadfence so prior writes are visible) performs the
final reduction via gridReduceLastBlock and is the one that returns true.
*/
static auto code_template_grid_reduction = R"(
namespace reduction {
// Utility functions
template<typename _dim3>
__host__ __device__ __forceinline__ size_t size(const _dim3& d) {
return (size_t)d.x * (size_t)d.y * (size_t)d.z;
}
#define isize(d) d.x * d.y * d.z
template<typename _dim3pos, typename _dim3dim>
__host__ __device__ __forceinline__ size_t offset(const _dim3pos& pos, const _dim3dim& dim) {
return (size_t)pos.x + (size_t)pos.y * (size_t)dim.x +
(size_t)pos.z * (size_t)dim.x * (size_t)dim.y;
}
#define ioffset(pos, dim) pos.x + pos.y * dim.x + pos.z * dim.x * dim.y
// Returns dim3 of each reduction segment.
template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3>
__host__ __device__ dim3 dimension_of_reduction_segment(const _dim3& grid_dim) {
return dim3{X_BLOCK ? grid_dim.x : 1,
Y_BLOCK ? grid_dim.y : 1,
Z_BLOCK ? grid_dim.z : 1};
}
// Returns the number of blocks in each reduction segment.
template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3>
__host__ __device__ size_t size_of_reduction_segment(const _dim3& grid_dim) {
return size(dimension_of_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(grid_dim));
}
// Returns the total number of reduction segments.
template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3>
__host__ __device__ size_t number_of_reduction_segments(const _dim3& grid_dim) {
return (X_BLOCK ? 1: grid_dim.x) *
(Y_BLOCK ? 1 : grid_dim.y) *
(Z_BLOCK ? 1 : grid_dim.z);
}
// Returns the 1-D index of the segment of thread block of block_idx.
template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3bi, typename _dim3gd>
__host__ __device__ size_t index_of_reduction_segment(const _dim3bi& block_idx,
const _dim3gd& grid_dim) {
size_t seg_idx = 0;
if (!Z_BLOCK)
seg_idx += block_idx.z;
if (!Y_BLOCK)
seg_idx = seg_idx * grid_dim.y + block_idx.y;
if (!X_BLOCK)
seg_idx = seg_idx * grid_dim.x + block_idx.x;
return seg_idx;
}
// Returns the offset of thread block in its reduction segment.
template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3bi, typename _dim3gd>
__host__ __device__ size_t offset_in_reduction_segment(const _dim3bi& block_idx,
const _dim3gd& grid_dim) {
size_t offset = 0;
if (Z_BLOCK)
offset = offset * grid_dim.z + block_idx.z;
if (Y_BLOCK)
offset = offset * grid_dim.y + block_idx.y;
if (X_BLOCK)
offset = offset * grid_dim.x + block_idx.x;
return offset;
}
// Returns dim3 of each reduction block.
template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, typename _dim3>
__host__ __device__ dim3 dimension_of_reduction_block(const _dim3& block_dim) {
return dim3{X_THREAD ? block_dim.x : 1,
Y_THREAD ? block_dim.y : 1,
Z_THREAD ? block_dim.z : 1};
}
// Returns the number of threads of each reduction block.
template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, typename _dim3>
__host__ __device__ int size_of_reduction_block(const _dim3& block_dim) {
auto tmp_dim = dimension_of_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(block_dim);
return isize(tmp_dim);
}
// Returns the linear offset of a thread in a reduction block.
template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, typename _dim3ti, typename _dim3bd>
__host__ __device__ int offset_in_reduction_block(const _dim3ti& thread_idx,
const _dim3bd& block_dim) {
int offset = 0;
if (Z_THREAD)
offset += thread_idx.z;
if (Y_THREAD)
offset = offset * block_dim.y + thread_idx.y;
if (X_THREAD)
offset = offset * block_dim.x + thread_idx.x;
return offset;
}
/** Reduces all the reduction blocks in each reduction segment.
This is only used by one thread block per reduction segment. The input
reduction blocks of the segment are stored in an intermediate buffer pointed
by parameter in. Template parameters X/Y/Z_THREAD denote how the reduction
block is formed.
The size of a reduction block is by definition smaller or equal to the size of
a thread block. We use the remaining threads to parallelize reductions across
reduction blocks. For example, when X/Y/Z_THREAD = {true, false, false}, we
use blockDim.y*blockDim.z threads for each output value. This is done first by
loading the input values in parallel and then by reducing across threads of
dimensions whose XYZ_THREAD are false.
Note that what is done here after the loading from global memory is similar to
what the existing blockReduce function does. The main difference is that the
logical block to reduce is a 2D domain where the leading dimension is the size
of a reduction block and the second dimension is the remaining factor in each
thread block. For example, when X/Y/Z_THREAD = {false, true, false}, the
threads are arranged as (blockDim.y, blockDim.x*blockDim.z). We do not reduce
along the first dimension but only the second dimension. So, it is possible to
reuse the existing blockReduce with dim3{blockDim.y, blockDim.x*blockDim.z}
instead of blockDim and with X_THREAD and Y_THREAD being false and true,
respectively. Also, it still need to shuffle the final output values to their
actual corresponding threads. In the case of when X/Y/Z_THREAD = {false, true,
false}, after the intra-block reduction, the final results will still be held
by the first blockDim.y threads, which need to be transferred to threads at
threadIdx.x == 0 and threadIdx.z == 0.
*/
template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD,
typename T, typename Func>
__device__ void gridReduceLastBlock(
T& out,
const T *in,
const size_t in_size,
Func reduction_op,
T* shared_buf,
bool read_write_pred,
T init_val) {
const int tid = ioffset(threadIdx, blockDim);
const int block_size = isize(blockDim);
const int rblock_size = size_of_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(blockDim);
T inp = init_val;
if (tid < in_size) {
inp = in[tid];
}
for (size_t i = tid + block_size; i < in_size; i += block_size) {
reduction_op(inp, in[i]);
}
const auto should_write = (X_THREAD || threadIdx.x == 0) &&
(Y_THREAD || threadIdx.y == 0) &&
(Z_THREAD || threadIdx.z == 0);
auto rem_size = block_size / rblock_size;
if (rem_size > 1) {
const int rblock_offset = tid % rblock_size;
const int rblock_idx = tid / rblock_size;
blockReduce<false, true, false>(
inp, inp, reduction_op,
dim3{(unsigned)rblock_offset, (unsigned)rblock_idx, 0},
dim3{(unsigned)rblock_size, (unsigned)rem_size},
shared_buf, true, init_val);
__syncthreads();
if (tid < rblock_size) {
shared_buf[tid] = inp;
}
__syncthreads();
if (should_write) {
inp = shared_buf[offset_in_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(
threadIdx, blockDim)];
}
}
if (should_write && read_write_pred) {
out = inp;
}
}
/** Reduces per-thread values across thread blocks.
Function parameters:
- out: Per-thread output location
- inp_val: Per-thread input value
- reduction_op: Scalar reduction function
- work_buf: Temporary buffer for cross-block reductions
- sync_flags: A vector of integers for synchronizations
- shared_buf: Shared memory buffer for intra-block reduction
Return true when the thread block has the valid result.
Template parameters:
- X/Y/Z_BLOCK: When true, reduces across thread blocks along the X/Y/Z
dimensions
- X/Y/Z_THREAD: When true, all threads along the X/Y/Z dimensions participate in
the cross-block reduction. Otherwise, only threads at offset 0 do.
- T: Scalar data type of input/output data
- Func: Type of scalara reduction function
Template parameters X/Y/Z_BLOCK define a group of thread blocks that are reduced together. We call
it a reduction segment. Some examples are:
Case 1: X/Y/Z_BLOCK == true/true/true -> There is only one segment, which includes all
thread blocks. It is effecively the same as the grid.
Case 2: X/Y/Z_BLOCK == false/false/false -> Each thread block comprises an individual
segment by itself.
Case 3: X/Y/Z_BLOCK == true/false/false -> Each segment contains thread blocks that have
the same blockDim.x. There will be blockDim.y*blockDim.z such segments.
X/Y/Z_THREAD defines a sub region of a thread block that should be reduced with
the sub regions of other thread blocks. We call it a reduction block. E.g.,
Case 1: X/Y/Z_THREAD == false/false/false -> Only thread 0 participates in the
cross-block reductions. The reduction block is 1x1x1 with thread 0.
Case 2: X/Y/Z_THREAD == true/true/true-> All threads in a thread block participate in
the cross-block reductions. The reduction block in this case is equivalent to
the thread block.
After the function completes, only one thread block per reduction segment gets
valid reduction results. There is no guarantee which particular block gets the
final results.
*/
template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK,
bool X_THREAD, bool Y_THREAD, bool Z_THREAD,
typename T, typename Func>
__device__ bool gridReduce(T& out, T inp_val, Func reduction_op,
volatile T* work_buf,
Tensor<int64_t, 1> sync_flags,
T* shared_buf, bool read_write_pred, T init_val) {
// Number of values to reduce in the grid dimensions
const auto seg_size =
size_of_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(gridDim);
// Index of the reduction we're performing out of the seg_size
const auto seg_idx =
index_of_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim);
// Number of threads we can use in final reduction, Seems to assume all threads in the block participate
const auto rblock_size =
size_of_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(blockDim);
// advance to the offset for this segment
// index of reduction * size of the reduction * size of threads
work_buf += seg_idx * seg_size * rblock_size;
if ((X_THREAD || threadIdx.x == 0) &&
(Y_THREAD || threadIdx.y == 0) &&
(Z_THREAD || threadIdx.z == 0)) {
auto rblock_offset =
offset_in_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim);
auto thread_offset =
offset_in_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(threadIdx, blockDim);
auto work_buf_offset = rblock_size * rblock_offset + thread_offset;
if(read_write_pred){
work_buf[work_buf_offset] = inp_val;
} else {
work_buf[work_buf_offset] = init_val;
}
}
__syncthreads();
__shared__ bool last_block;
if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
__threadfence();
// printf("%ld\n", sync_flags[seg_idx]);
auto old = (int64_t) atomicAdd( (unsigned long long*) &sync_flags[seg_idx], 1);
last_block = old + 1 == seg_size;
// printf("Last_block = %d + 1 == %d\n", (int)old, (int)seg_size);
}
__syncthreads();
if (last_block) {
// printf("Last block %d %d %d %d\n", blockIdx.x, blockIdx.y, blockIdx.z);
// final reduction
gridReduceLastBlock<X_THREAD, Y_THREAD, Z_THREAD>(
out, (T*)work_buf, seg_size * rblock_size,
reduction_op, shared_buf, read_write_pred, init_val);
return true;
} else {
// printf("Not last block %d %d %d\n", blockIdx.x, blockIdx.y, blockIdx.z);
return false;
}
}
} // namespace reduction
)";
// Intra-block broadcast: the thread at offset 0 along every dimension whose
// X/Y/Z_THREAD flag is true writes its value to shared memory, and after a
// __syncthreads() every thread in the same group reads it back. Flags that
// are false keep that dimension's index in the shared-memory offset, so
// threads differing along a false dimension form independent groups.
static auto code_template_block_broadcast = R"(
namespace broadcast {
template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD>
__host__ __device__ unsigned offset_of_source(const dim3& block_dim, const dim3& thread_idx) {
unsigned offset = 0;
if (!Z_THREAD)
offset = offset * block_dim.z + thread_idx.z;
if (!Y_THREAD)
offset = offset * block_dim.y + thread_idx.y;
if (!X_THREAD)
offset = offset * block_dim.x + thread_idx.x;
return offset;
}
/** Broadcasts within partitioned groups of threads.
X_THREAD: Broadcast from threadIdx.x == 0 if true
Y_THREAD: Broadcast from threadIdx.y == 0 if true
Z_THREAD: Broadcast from threadIdx.z == 0 if true
inp_val: Per-thread source value. Only valid when the thread is a source.
out: Per-thread output location
*/
template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, typename T>
__device__ void blockBroadcast(T& out, T inp_val, T* shared_mem) {
const bool has_valid_data =
(!X_THREAD || threadIdx.x == 0) &&
(!Y_THREAD || threadIdx.y == 0) &&
(!Z_THREAD || threadIdx.z == 0);
const auto shared_offset = offset_of_source<X_THREAD, Y_THREAD, Z_THREAD>(blockDim, threadIdx);
if (has_valid_data)
shared_mem[shared_offset] = inp_val;
__syncthreads();
out = shared_mem[shared_offset];
}
} // namespace broadcast
)";
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch