-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
layernorm_backward.cu
1552 lines (1392 loc) · 70.3 KB
/
layernorm_backward.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
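/*
Kernels for the layernorm backward pass.

Compile example:
nvcc -O3 --use_fast_math -lcublas -lcublasLt layernorm_backward.cu -o layernorm_backward

version 1 is a naive port of the CPU code to a kernel: it parallelizes over B,T and loops over C
./layernorm_backward 1

version 2 uses one warp per (B,T) row with warp-level reductions, and accumulates dweight/dbias in
shared memory instead of directly in global memory
./layernorm_backward 2

later versions (3 onwards) are further optimizations; each one is described in the comment above its kernel
*/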
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>
#include <assert.h>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#define ENABLE_BF16
#include "common.h"
// ----------------------------------------------------------------------------
// CPU code reference
void layernorm_forward_cpu(float* out, float* mean, float* rstd,
                           const float* inp, const float* weight, const float* bias,
                           int B, int T, int C) {
    // Reference LayerNorm forward (https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html).
    // inp and out hold (B,T,C) activations stored contiguously. Each row of C values is normalized to
    // zero mean / unit (biased) variance, then scaled by weight[c] and shifted by bias[c].
    // mean and rstd are (B,T) buffers receiving the per-row statistics, for reuse in the backward pass.
    const float eps = 1e-5f;
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            const int row = b * T + t;
            const float* x = inp + row * C;
            float* y = out + row * C;

            // row mean
            float sum = 0.0f;
            for (int c = 0; c < C; c++) {
                sum += x[c];
            }
            const float mu = sum / C;

            // variance, divided by C (no Bessel correction)
            float sq_sum = 0.0f;
            for (int c = 0; c < C; c++) {
                const float centered = x[c] - mu;
                sq_sum += centered * centered;
            }
            const float var = sq_sum / C;

            // reciprocal standard deviation
            const float inv_std = 1.0f / sqrtf(var + eps);

            // normalize, then scale and shift
            for (int c = 0; c < C; c++) {
                const float x_hat = inv_std * (x[c] - mu);
                y[c] = x_hat * weight[c] + bias[c];
            }

            // statistics cached for the backward pass
            mean[row] = mu;
            rstd[row] = inv_std;
        }
    }
}
void layernorm_backward_cpu(float* dinp, float* dweight, float* dbias,
                            const float* dout, const float* inp, const float* weight, const float* mean, const float* rstd,
                            int B, int T, int C) {
    // Reference LayerNorm backward. All gradients are ACCUMULATED (+=) into dinp (B,T,C), dweight (C)
    // and dbias (C), so callers must zero them first if they want a fresh gradient.
    // mean and rstd are the (B,T) per-row statistics written by layernorm_forward_cpu.
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            const int row = b * T + t;
            const float* g = dout + row * C;
            const float* x = inp + row * C;
            float* dx = dinp + row * C;
            const float mu = mean[row];
            const float inv_std = rstd[row];

            // Pass 1: row means of dnorm = weight * dout and of dnorm * x_hat.
            float sum_dnorm = 0.0f;
            float sum_dnorm_xhat = 0.0f;
            for (int c = 0; c < C; c++) {
                const float x_hat = (x[c] - mu) * inv_std;
                const float dnorm = weight[c] * g[c];
                sum_dnorm += dnorm;
                sum_dnorm_xhat += dnorm * x_hat;
            }
            const float mean_dnorm = sum_dnorm / C;
            const float mean_dnorm_xhat = sum_dnorm_xhat / C;

            // Pass 2: parameter gradients and the input gradient.
            for (int c = 0; c < C; c++) {
                const float x_hat = (x[c] - mu) * inv_std;
                const float dnorm = weight[c] * g[c];
                dbias[c] += g[c];
                dweight[c] += x_hat * g[c];
                // dx = inv_std * (dnorm - mean(dnorm) - x_hat * mean(dnorm * x_hat)), built term by term
                float grad = 0.0f;
                grad += dnorm;
                grad -= mean_dnorm;
                grad -= x_hat * mean_dnorm_xhat;
                grad *= inv_std;
                dx[c] += grad;
            }
        }
    }
}
// ----------------------------------------------------------------------------
// GPU kernels
// GPU helper functions for atomicAdd on smaller than 32-bit types
#ifdef ENABLE_BF16
// bf16 overload: emulated with a 32-bit __nv_bfloat162 atomicAdd on the 4-byte-aligned word that contains
// *addr, adding a zero to the neighbouring bf16. If addr sits at byte offset 2 within its word, the value goes
// into the high half; otherwise it goes into the low half. Assumes addr is at least 2-byte aligned.
// NOTE(review): native __nv_bfloat162 atomicAdd is documented for SM80+ -- confirm target architectures.
__device__ void atomicAddX(__nv_bfloat16* addr, __nv_bfloat16 val) {
    const uintptr_t address = reinterpret_cast<uintptr_t>(addr);
    const bool upper_half = (address & 0x3) != 0;
    __nv_bfloat162* word = reinterpret_cast<__nv_bfloat162*>(address & ~uintptr_t(0x3));
    const __nv_bfloat16 zero = __ushort_as_bfloat16(0);
    __nv_bfloat162 addend;
    if (upper_half) {
        addend = __halves2bfloat162(zero, val);
    } else {
        addend = __halves2bfloat162(val, zero);
    }
    atomicAdd(word, addend);
}
#endif
#ifdef ENABLE_FP16
// half overload: emulated with a 32-bit half2 atomicAdd on the 4-byte-aligned word that contains *addr,
// adding a zero to the neighbouring half. If addr sits at byte offset 2 within its word, the value goes into
// the high half; otherwise it goes into the low half. Assumes addr is at least 2-byte aligned.
__device__ void atomicAddX(half* addr, half val) {
    const uintptr_t address = reinterpret_cast<uintptr_t>(addr);
    const bool upper_half = (address & 0x3) != 0;
    half2* word = reinterpret_cast<half2*>(address & ~uintptr_t(0x3));
    const half zero = __ushort_as_half(0);
    const half2 addend = upper_half ? __halves2half2(zero, val)
                                    : __halves2half2(val, zero);
    atomicAdd(word, addend);
}
#endif
// float overload: forwards to the native float atomicAdd so templated kernels can call atomicAddX for any
// parameter type. The previous value that atomicAdd returns is intentionally discarded.
__device__ void atomicAddX(float* addr, float val) {
    (void)atomicAdd(addr, val);
}
// Kernel 1 (naive): one thread per (b, t) row; each thread walks the whole C dimension twice.
// Launch: any 1D grid providing at least B*T threads (extra threads do nothing); no shared memory.
// dbias/dweight are accumulated with global float atomicAdd; dinp is accumulated (+=), not overwritten.
__global__ void layernorm_backward_kernel1(float* dinp, float* dweight, float* dbias,
                        const float* dout, const float* inp, const float* weight, const float* mean, const float* rstd,
                        int B, int T, int C) {
    const int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < B * T) {
        const int b = row / T;
        const int t = row % T;
        const int offset = b * T * C + t * C;
        const float* g = dout + offset;
        const float* x = inp + offset;
        float* dx = dinp + offset;
        const float mu = mean[b * T + t];
        const float inv_std = rstd[b * T + t];

        // Pass 1: the two row-wise reductions needed by the input gradient.
        float sum_dnorm = 0.0f;
        float sum_dnorm_xhat = 0.0f;
        for (int c = 0; c < C; c++) {
            const float x_hat = (x[c] - mu) * inv_std;
            const float dnorm = weight[c] * g[c];
            sum_dnorm += dnorm;
            sum_dnorm_xhat += dnorm * x_hat;
        }
        const float mean_dnorm = sum_dnorm / C;
        const float mean_dnorm_xhat = sum_dnorm_xhat / C;

        // Pass 2: parameter gradients (global atomics) and the input gradient.
        for (int c = 0; c < C; c++) {
            const float x_hat = (x[c] - mu) * inv_std;
            const float dnorm = weight[c] * g[c];
            atomicAdd(&dbias[c], g[c]);
            atomicAdd(&dweight[c], x_hat * g[c]);
            float grad = 0.0f;
            grad += dnorm;                     // term 1
            grad -= mean_dnorm;                // term 2
            grad -= x_hat * mean_dnorm_xhat;   // term 3
            grad *= inv_std;                   // final scale
            dx[c] += grad;
        }
    }
}
// Kernel 2: one warp per (b, t) row. The two row reductions use cooperative-groups warp reductions.
// dbias/dweight contributions are accumulated per block in shared memory (float atomics), then added into the
// FP32 buffers dbias_tmp/dweight_tmp with global atomicAdd. The dweight/dbias (Tparams) parameters are
// not touched here; presumably the host converts the tmp buffers afterwards (e.g. with copy_to_dweight_dbias).
// Launch: enough blocks that (blocks * warps per block) >= B*T. Dynamic shared memory: 2 * C floats.
// dinp is accumulated (+=), not overwritten.
template <typename Tdinp, typename Tparams, typename Tdout, typename Trest>
__global__ void layernorm_backward_kernel2(Tdinp* dinp, Tparams* dweight, Tparams* dbias,
                        const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,
                        int B, int T, int C, float* dweight_tmp, float* dbias_tmp) {
    extern __shared__ float shared[]; // size = 2 * C
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
    int N = B * T;

    // the first half of shared memory is bias, second is weight
    float* dbias_shared = shared;
    float* dweight_shared = shared + C;
    // Init shared memory to zero. Every thread of the block must take part (there is deliberately no early
    // return above): otherwise, in the tail block, the channels owned by idle threads would stay
    // uninitialised and would also be skipped by the final flush.
    for (int i = threadIdx.x; i < C; i += blockDim.x) {
        dbias_shared[i] = 0.0f;
        dweight_shared[i] = 0.0f;
    }
    __syncthreads();

    // Warps whose row lies past the end skip the work but still reach both barriers. idx is uniform within a
    // warp, so the warp-level reductions below always run with the full warp.
    if (idx < N) {
        int b = idx / T;
        int t = idx % T;
        const Tdout* dout_bt = dout + b * T * C + t * C;
        const Trest* inp_bt = inp + b * T * C + t * C;
        Tdinp* dinp_bt = dinp + b * T * C + t * C;
        const float mean_bt = (float)mean[b * T + t];
        const float rstd_bt = (float)rstd[b * T + t];

        // first: two reduce operations
        float dnorm_mean = 0.0f;
        float dnorm_norm_mean = 0.0f;
        for (int i = warp.thread_rank(); i < C; i += warp.size()) {
            float norm_bti = ((float)inp_bt[i] - mean_bt) * rstd_bt;
            float dnorm_i = (float)weight[i] * (float)dout_bt[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = cg::reduce(warp, dnorm_mean, cg::plus<float>{});
        dnorm_norm_mean = cg::reduce(warp, dnorm_norm_mean, cg::plus<float>{});
        dnorm_mean = dnorm_mean / C;
        dnorm_norm_mean = dnorm_norm_mean / C;

        // now iterate again and accumulate all the gradients
        for (int i = warp.thread_rank(); i < C; i += warp.size()) {
            float norm_bti = ((float)inp_bt[i] - mean_bt) * rstd_bt;
            float dnorm_i = (float)weight[i] * (float)dout_bt[i];
            // gradient contribution to bias
            atomicAdd(&dbias_shared[i], (float)dout_bt[i]);
            // gradient contribution to weight
            atomicAdd(&dweight_shared[i], norm_bti * (float)dout_bt[i]);
            // gradient contribution to input
            float dval = 0.0f;
            dval += dnorm_i; // term 1
            dval -= dnorm_mean; // term 2
            dval -= norm_bti * dnorm_norm_mean; // term 3
            dval *= rstd_bt; // final scale
            dinp_bt[i] = (Tdinp)((float)dinp_bt[i] + dval);
        }
    }
    __syncthreads();

    // write this block's partials to global memory
    for (int i = threadIdx.x; i < C; i += blockDim.x) {
        atomicAdd(&dbias_tmp[i], dbias_shared[i]);
        atomicAdd(&dweight_tmp[i], dweight_shared[i]);
    }
}
// Converts the FP32 accumulators dbias_tmp/dweight_tmp into the Tparams gradient buffers, overwriting dbias/dweight.
// Each thread starts at its global thread index and advances by the total number of threads in the grid,
// so any 1D launch configuration covers all C channels.
template <typename Tparams>
__global__ void copy_to_dweight_dbias(int C, Tparams* dbias, Tparams* dweight, float* dbias_tmp, float* dweight_tmp) {
    const int stride = blockDim.x * gridDim.x;
    int c = threadIdx.x + blockDim.x * blockIdx.x;
    while (c < C) {
        dbias[c] = (Tparams)dbias_tmp[c];
        dweight[c] = (Tparams)dweight_tmp[c];
        c += stride;
    }
}
// Kernel 3 (derived from kernel 2). Kernel 2 gives every (b, t) row its own warp, so a 1024-thread block
// (32 warps) covers 32 rows, and every such block flushes its shared-memory partials with global atomics.
// To cut down on global atomics, kernel 3 is meant to be launched with about one threadblock per SM.
// Each warp strides over rows by the total number of warps in the grid, and each block flushes its
// dbias/dweight partials only once, at the end, via atomicAddX in Tparams precision.
// Dynamic shared memory: 2 * C floats (first C for dbias, next C for dweight). dinp is accumulated (+=).
template <typename Tdinp, typename Tparams, typename Tdout, typename Trest>
__global__ void layernorm_backward_kernel3(Tdinp* dinp, Tparams* dweight, Tparams* dbias,
                        const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,
                        int B, int T, int C) {
    extern __shared__ float shared[]; // size = 2 * C
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);

    float* dbias_acc = shared;       // per-block dbias partials
    float* dweight_acc = shared + C; // per-block dweight partials
    #pragma unroll 4
    for (int c = threadIdx.x; c < C; c += blockDim.x) {
        dbias_acc[c] = 0.0f;
        dweight_acc[c] = 0.0f;
    }
    __syncthreads();

    const int lane = warp.thread_rank();
    const int lanes = warp.size();
    const int first_row = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
    const int total_warps = gridDim.x * warp.meta_group_size();
    for (int row = first_row; row < B * T; row += total_warps) {
        const int b = row / T;
        const int t = row % T;
        const int offset = b * T * C + t * C;
        const Tdout* g = dout + offset;
        const Trest* x = inp + offset;
        Tdinp* dx = dinp + offset;
        const float mu = (float)mean[b * T + t];
        const float inv_std = (float)rstd[b * T + t];

        // Pass 1: per-lane partial sums, then warp-wide reductions turned into row means.
        float sum_dnorm = 0.0f;
        float sum_dnorm_xhat = 0.0f;
        for (int c = lane; c < C; c += lanes) {
            const float x_hat = ((float)x[c] - mu) * inv_std;
            const float dnorm = (float)weight[c] * (float)g[c];
            sum_dnorm += dnorm;
            sum_dnorm_xhat += dnorm * x_hat;
        }
        sum_dnorm = cg::reduce(warp, sum_dnorm, cg::plus<float>{});
        sum_dnorm_xhat = cg::reduce(warp, sum_dnorm_xhat, cg::plus<float>{});
        const float mean_dnorm = sum_dnorm / C;
        const float mean_dnorm_xhat = sum_dnorm_xhat / C;

        // Pass 2: streaming (__ldcs) loads of dout/inp, since this is their last use in this kernel.
        for (int c = lane; c < C; c += lanes) {
            const float g_c = (float)__ldcs(&g[c]);
            const float x_hat = ((float)__ldcs(&x[c]) - mu) * inv_std;
            const float dnorm = (float)weight[c] * g_c;
            atomicAdd(&dbias_acc[c], g_c);
            atomicAdd(&dweight_acc[c], x_hat * g_c);
            float grad = 0.0f;
            grad += dnorm;                     // term 1
            grad -= mean_dnorm;                // term 2
            grad -= x_hat * mean_dnorm_xhat;   // term 3
            grad *= inv_std;                   // final scale
            dx[c] = (Tdinp)((float)dx[c] + grad);
        }
    }
    __syncthreads();

    // One global flush per block, in parameter precision.
    for (int c = threadIdx.x; c < C; c += blockDim.x) {
        atomicAddX(&dbias[c], (Tparams)dbias_acc[c]);
        atomicAddX(&dweight[c], (Tparams)dweight_acc[c]);
    }
}
// Adds `val` to the __nv_bfloat162 at `addr` with a 32-bit atomicCAS retry loop (helper for kernel 4).
// The first guess for the current contents comes from an L2 load (__ldcg); if another thread updated the pair
// in between, the CAS fails and the loop retries on top of the value it observed.
__device__ __forceinline__ void atomicAddBf162CAS(__nv_bfloat162* addr, __nv_bfloat162 val) {
    unsigned int* word = reinterpret_cast<unsigned int*>(addr);
    __nv_bfloat162 guess = __ldcg(addr);
    unsigned int expected = *reinterpret_cast<unsigned int*>(&guess);
    while (true) {
        __nv_bfloat162 sum = val + *reinterpret_cast<__nv_bfloat162*>(&expected);
        unsigned int desired = *reinterpret_cast<unsigned int*>(&sum);
        unsigned int observed = atomicCAS(word, expected, desired);
        if (observed == expected) {
            return; // our sum was stored
        }
        expected = observed; // lost the race: retry with the freshly observed contents
    }
}

// Kernel 4: atomicCAS version of kernel 3. The main loop is identical. At the end, each block flushes its
// shared-memory partials two channels at a time, as __nv_bfloat162 pairs, with 32-bit compare-and-swap.
// Preconditions: dbias/dweight hold __nv_bfloat16 data (they are reinterpreted as __nv_bfloat162) and are
// 4-byte aligned. If C is odd, the final channel has no partner and is flushed separately with atomicAddX.
// Dynamic shared memory: 2 * C floats. dinp is accumulated (+=).
template <typename Tdinp, typename Tparams, typename Tdout, typename Trest>
__global__ void layernorm_backward_kernel4(Tdinp* dinp, Tparams* dweight, Tparams* dbias,
                        const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,
                        int B, int T, int C) {
    extern __shared__ float shared[]; // size = 2 * C
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int base_idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();

    // the first half of shared memory is bias, second is weight
    float* dbias_shared = shared;
    float* dweight_shared = shared + C;
    // init shared memory to zero
    #pragma unroll 4
    for (int i = threadIdx.x; i < C; i += blockDim.x) {
        dbias_shared[i] = 0.0f;
        dweight_shared[i] = 0.0f;
    }
    __syncthreads();

    int warps_in_grid = gridDim.x * warp.meta_group_size();
    for (int idx = base_idx; idx < B * T; idx += warps_in_grid) {
        int b = idx / T;
        int t = idx % T;
        const Tdout* dout_bt = dout + b * T * C + t * C;
        const Trest* inp_bt = inp + b * T * C + t * C;
        Tdinp* dinp_bt = dinp + b * T * C + t * C;
        const float mean_bt = (float)mean[b * T + t];
        const float rstd_bt = (float)rstd[b * T + t];

        // first: two reduce operations
        float dnorm_mean = 0.0f;
        float dnorm_norm_mean = 0.0f;
        for (int i = warp.thread_rank(); i < C; i += warp.size()) {
            float norm_bti = ((float)inp_bt[i] - mean_bt) * rstd_bt;
            float dnorm_i = (float)weight[i] * (float)dout_bt[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = cg::reduce(warp, dnorm_mean, cg::plus<float>{});
        dnorm_norm_mean = cg::reduce(warp, dnorm_norm_mean, cg::plus<float>{});
        dnorm_mean = dnorm_mean / C;
        dnorm_norm_mean = dnorm_norm_mean / C;

        // now iterate again and accumulate all the gradients
        for (int i = warp.thread_rank(); i < C; i += warp.size()) {
            float dout_i = (float)__ldcs(&dout_bt[i]);
            float norm_bti = ((float)__ldcs(&inp_bt[i]) - mean_bt) * rstd_bt;
            float dnorm_i = (float)weight[i] * dout_i;
            // gradient contribution to bias
            atomicAdd(&dbias_shared[i], dout_i);
            // gradient contribution to weight
            atomicAdd(&dweight_shared[i], norm_bti * dout_i);
            // gradient contribution to input
            float dval = 0.0f;
            dval += dnorm_i; // term 1
            dval -= dnorm_mean; // term 2
            dval -= norm_bti * dnorm_norm_mean; // term 3
            dval *= rstd_bt; // final scale
            dinp_bt[i] = (Tdinp)((float)dinp_bt[i] + dval);
        }
    }
    __syncthreads();

    // write to global memory, two channels per 32-bit CAS
    __nv_bfloat162* dbiasVec2 = reinterpret_cast<__nv_bfloat162*>(dbias);
    __nv_bfloat162* dweightVec2 = reinterpret_cast<__nv_bfloat162*>(dweight);
    for (int i = threadIdx.x; i < C / 2; i += blockDim.x) {
        __nv_bfloat162 add_dbias = __halves2bfloat162((__nv_bfloat16)dbias_shared[i*2], (__nv_bfloat16)dbias_shared[i*2+1]);
        __nv_bfloat162 add_dweight = __halves2bfloat162((__nv_bfloat16)dweight_shared[i*2], (__nv_bfloat16)dweight_shared[i*2+1]);
        atomicAddBf162CAS(&dbiasVec2[i], add_dbias);
        atomicAddBf162CAS(&dweightVec2[i], add_dweight);
    }
    // Odd C: the pair loop above stops at C/2 pairs, so the last channel must be flushed on its own.
    if ((C & 1) != 0 && threadIdx.x == 0) {
        atomicAddX(&dbias[C - 1], (Tparams)dbias_shared[C - 1]);
        atomicAddX(&dweight[C - 1], (Tparams)dweight_shared[C - 1]);
    }
}
// Kernel 5 (based on kernel 3): each threadblock accumulates its dbias/dweight partials in shared memory
// (shared-memory float atomics), then writes them with plain stores to its own FP32 slice of a global scratchpad.
// The only global atomic is the atomicAdd on the unsigned int completion counter. The block that gets back
// count == gridDim.x - 1 sums every slice and adds the totals into dbias/dweight (accumulating).
// Scratch layout (floats): [0, C*gridDim.x) dbias slices, [C*gridDim.x, 2*C*gridDim.x) dweight slices, followed
// by the unsigned int counter. The counter must be zero at launch; this kernel never resets it.
// Dynamic shared memory: (2 * C + 1) * sizeof(float). dinp is accumulated (+=).
template <typename Tdinp, typename Tparams, typename Tdout, typename Trest>
__global__ void layernorm_backward_kernel5(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch,
const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,
int B, int T, int C) {
extern __shared__ float shared[]; // size = 2 * C + 1
namespace cg = cooperative_groups;
cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
// one warp per row; warps stride over rows by the total number of warps in the grid (see the loop below)
int base_idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
// the first half of shared memory is bias, second is weight
float* dbias_shared = shared;
float* dweight_shared = shared + C;
// init shared memory to zero
#pragma unroll 4
for(int i = threadIdx.x; i < C; i+= blockDim.x){
dbias_shared[i] = 0.0f;
dweight_shared[i] = 0.0f;
}
// the extra shared slot after the 2*C accumulators broadcasts the counter value obtained by thread 0
unsigned int *tmp_flag = (unsigned int*)(shared + C*2);
__syncthreads();
int warps_in_grid = gridDim.x * warp.meta_group_size();
for (int idx = base_idx; idx < B * T; idx += warps_in_grid) {
int b = idx / T;
int t = idx % T;
const Tdout* dout_bt = dout + b * T * C + t * C;
const Trest* inp_bt = inp + b * T * C + t * C;
Tdinp* dinp_bt = dinp + b * T * C + t * C;
const float mean_bt = (float)mean[b * T + t];
const float rstd_bt = (float)rstd[b * T + t];
// first: two reduce operations
float dnorm_mean = 0.0f;
float dnorm_norm_mean = 0.0f;
for (int i = warp.thread_rank(); i < C; i += warp.size()) {
float norm_bti = ((float)inp_bt[i] - mean_bt) * rstd_bt;
float dnorm_i = (float)weight[i] * (float)dout_bt[i];
dnorm_mean += dnorm_i;
dnorm_norm_mean += dnorm_i * norm_bti;
}
dnorm_mean = cg::reduce(warp, dnorm_mean, cg::plus<float>{});
dnorm_norm_mean = cg::reduce(warp, dnorm_norm_mean, cg::plus<float>{});
dnorm_mean = dnorm_mean / C;
dnorm_norm_mean = dnorm_norm_mean / C;
// now iterate again and accumulate all the gradients
// (dout/inp are read with __ldcs, a streaming load, since this is their last use here)
for (int i = warp.thread_rank(); i < C; i += warp.size()) {
float dout_i = (float)__ldcs(&dout_bt[i]);
float norm_bti = ((float)__ldcs(&inp_bt[i]) - mean_bt) * rstd_bt;
float dnorm_i = (float)weight[i] * dout_i;
// gradient contribution to bias
atomicAdd(&dbias_shared[i], dout_i);
// gradient contribution to weight
atomicAdd(&dweight_shared[i], norm_bti * dout_i);
// gradient contribution to input
float dval = 0.0f;
dval += dnorm_i; // term 1
dval -= dnorm_mean; // term 2
dval -= norm_bti * dnorm_norm_mean; // term 3
dval *= rstd_bt; // final scale
dinp_bt[i] = (Tdinp)((float)dinp_bt[i] + dval);
}
}
__syncthreads();
// this block's private slices of the scratchpad, then the shared completion counter
float* scratch_dbias = scratch;
float* scratch_dweight = scratch + C * gridDim.x;
unsigned int* scratchFlag = (unsigned int*)(scratch + (2 * C * gridDim.x));
for(int i = threadIdx.x; i < C; i+= blockDim.x) {
scratch_dbias[i + C*blockIdx.x] = dbias_shared[i];
scratch_dweight[i + C*blockIdx.x] = dweight_shared[i];
}
// Publish: make this block's slice visible device-wide before it signals completion through the counter.
// The order (stores, fence, barrier, counter increment) must be kept.
__threadfence();
__syncthreads();
if (threadIdx.x == 0) {
*tmp_flag = atomicAdd(scratchFlag, 1);
}
__syncthreads();
// Only the block that got back gridDim.x - 1 (the last one to increment) performs the final reduction.
if (*tmp_flag == gridDim.x-1) {
// last block to finish, accumulate the scratchpad
for (int i = threadIdx.x; i < C; i += blockDim.x) {
float dbias_sum = 0.0f;
float dweight_sum = 0.0f;
#pragma unroll 8
for (int j = 0; j < gridDim.x; j++) {
dbias_sum += scratch_dbias[i + j*C];
dweight_sum += scratch_dweight[i + j*C];
}
// add (rather than overwrite) into the existing parameter gradients
dbias[i] = (Tparams)((float)dbias[i] + dbias_sum);
dweight[i] = (Tparams)((float)dweight[i] + dweight_sum);
}
}
}
// Kernel 6 (based on kernels 3 & 5): each block accumulates its dbias/dweight partials in shared memory, then
// merges them with float atomicAdd into a single FP32 scratchpad shared by all threadblocks.
// Scratch layout (floats): [0, C) dbias, [C, 2C) dweight, then one unsigned int completion counter. The whole
// scratchpad must be zero at launch, because it is accumulated into and never reset here.
// The last block to finish adds the FP32 totals into dbias/dweight. Parameter gradients therefore accumulate,
// as in layernorm_backward_cpu and kernels 1 and 3-5. Dynamic shared memory: (2 * C + 1) * sizeof(float).
// dinp is accumulated (+=).
template <typename Tdinp, typename Tparams, typename Tdout, typename Trest>
__global__ void layernorm_backward_kernel6(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch,
                        const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,
                        int B, int T, int C) {
    extern __shared__ float shared[]; // size = 2 * C + 1
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int base_idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();

    // the first half of shared memory is bias, second is weight
    float* dbias_shared = shared;
    float* dweight_shared = shared + C;
    // init shared memory to zero
    #pragma unroll 4
    for (int i = threadIdx.x; i < C; i += blockDim.x) {
        dbias_shared[i] = 0.0f;
        dweight_shared[i] = 0.0f;
    }
    // the extra shared slot after the accumulators broadcasts the counter value obtained by thread 0
    unsigned int *tmp_flag = (unsigned int*)(shared + C*2);
    __syncthreads();

    // one warp per row; warps stride over rows by the total number of warps in the grid
    int warps_in_grid = gridDim.x * warp.meta_group_size();
    for (int idx = base_idx; idx < B * T; idx += warps_in_grid) {
        int b = idx / T;
        int t = idx % T;
        const Tdout* dout_bt = dout + b * T * C + t * C;
        const Trest* inp_bt = inp + b * T * C + t * C;
        Tdinp* dinp_bt = dinp + b * T * C + t * C;
        const float mean_bt = (float)mean[b * T + t];
        const float rstd_bt = (float)rstd[b * T + t];

        // first: two reduce operations
        float dnorm_mean = 0.0f;
        float dnorm_norm_mean = 0.0f;
        for (int i = warp.thread_rank(); i < C; i += warp.size()) {
            float norm_bti = ((float)inp_bt[i] - mean_bt) * rstd_bt;
            float dnorm_i = (float)weight[i] * (float)dout_bt[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = cg::reduce(warp, dnorm_mean, cg::plus<float>{});
        dnorm_norm_mean = cg::reduce(warp, dnorm_norm_mean, cg::plus<float>{});
        dnorm_mean = dnorm_mean / C;
        dnorm_norm_mean = dnorm_norm_mean / C;

        // now iterate again and accumulate all the gradients
        for (int i = warp.thread_rank(); i < C; i += warp.size()) {
            float dout_i = (float)__ldcs(&dout_bt[i]);
            float norm_bti = ((float)__ldcs(&inp_bt[i]) - mean_bt) * rstd_bt;
            float dnorm_i = (float)weight[i] * dout_i;
            // gradient contribution to bias
            atomicAdd(&dbias_shared[i], dout_i);
            // gradient contribution to weight
            atomicAdd(&dweight_shared[i], norm_bti * dout_i);
            // gradient contribution to input
            float dval = 0.0f;
            dval += dnorm_i; // term 1
            dval -= dnorm_mean; // term 2
            dval -= norm_bti * dnorm_norm_mean; // term 3
            dval *= rstd_bt; // final scale
            dinp_bt[i] = (Tdinp)((float)dinp_bt[i] + dval);
        }
    }

    // Accumulate into a FP32 scratchpad
    // BF16 atomics are potentially much slower... and this is more precise!
    __syncthreads();
    float* scratch_dbias = scratch;
    float* scratch_dweight = scratch + C;
    unsigned int* scratchFlag = (unsigned int*)(scratch + (2 * C));
    for (int i = threadIdx.x; i < C; i += blockDim.x) {
        atomicAdd(&scratch_dbias[i], dbias_shared[i]);
        atomicAdd(&scratch_dweight[i], dweight_shared[i]);
    }
    // Make this block's scratchpad updates visible device-wide before announcing completion through the
    // counter (the same fence-then-signal order kernel 5 uses). Without the fence, the last block is not
    // guaranteed to observe every other block's contribution.
    __threadfence();
    __syncthreads();
    if (threadIdx.x == 0) {
        *tmp_flag = atomicAdd(scratchFlag, 1);
    }
    __syncthreads();
    if (*tmp_flag == gridDim.x-1) {
        // last block: every other block has published its contribution
        for (int i = threadIdx.x; i < C; i += blockDim.x) {
            // __ldcg reads the totals from L2, where the float atomics were performed, bypassing L1.
            // Add into the existing gradients instead of overwriting them.
            // todo - potentially do stochastic rounding here as well
            dbias[i] = (Tparams)((float)dbias[i] + __ldcg(&scratch_dbias[i]));
            dweight[i] = (Tparams)((float)dweight[i] + __ldcg(&scratch_dweight[i]));
        }
    }
}
// Kernel 7: the same algorithm as kernel 6, written without cooperative groups or templates; all tensors are floatX.
// Warps come from threadIdx.x / warpSize. The row reductions use warpReduceSum (defined outside this file;
// assumed to return the full-warp sum to every lane).
// Scratch layout (floats): [0, C) dbias, [C, 2C) dweight, then one unsigned int completion counter. The whole
// scratchpad must be zero at launch, because it is accumulated into and never reset here.
// The last block to finish adds the FP32 totals into dbias/dweight (accumulating, like layernorm_backward_cpu).
// Dynamic shared memory: (2 * C + 1) * sizeof(float). dinp is accumulated (+=).
__global__ void layernorm_backward_kernel7(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,
                        const floatX* dout, const floatX* inp, const floatX* weight, const floatX* mean, const floatX* rstd,
                        int B, int T, int C) {
    extern __shared__ float shared[]; // size = 2 * C + 1
    int warpId = threadIdx.x / warpSize; // warp index within a block
    int warpsInBlock = blockDim.x / warpSize;
    int base_idx = blockIdx.x * warpsInBlock + warpId;
    int warpThreadIdx = threadIdx.x % warpSize; // Thread index within the warp
    int warps_in_grid = gridDim.x * warpsInBlock;

    // the first half of shared memory is bias, second is weight
    float* dbias_shared = shared;
    float* dweight_shared = shared + C;
    // init shared memory to zero
    #pragma unroll 4
    for (int i = threadIdx.x; i < C; i += blockDim.x) {
        dbias_shared[i] = 0.0f;
        dweight_shared[i] = 0.0f;
    }
    // the extra shared slot after the accumulators broadcasts the counter value obtained by thread 0
    unsigned int *tmp_flag = (unsigned int*)(shared + C*2);
    __syncthreads();

    for (int idx = base_idx; idx < B * T; idx += warps_in_grid) {
        int b = idx / T;
        int t = idx % T;
        const floatX* dout_bt = dout + b * T * C + t * C;
        const floatX* inp_bt = inp + b * T * C + t * C;
        floatX* dinp_bt = dinp + b * T * C + t * C;
        const float mean_bt = (float)mean[b * T + t];
        const float rstd_bt = (float)rstd[b * T + t];

        // first: two reduce operations
        float dnorm_mean = 0.0f;
        float dnorm_norm_mean = 0.0f;
        for (int i = warpThreadIdx; i < C; i += warpSize) {
            float norm_bti = ((float)inp_bt[i] - mean_bt) * rstd_bt;
            float dnorm_i = (float)weight[i] * (float)dout_bt[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = warpReduceSum(dnorm_mean);
        dnorm_norm_mean = warpReduceSum(dnorm_norm_mean);
        dnorm_mean = dnorm_mean / C;
        dnorm_norm_mean = dnorm_norm_mean / C;

        // now iterate again and accumulate all the gradients
        for (int i = warpThreadIdx; i < C; i += warpSize) {
            float dout_i = (float)__ldcs(&dout_bt[i]);
            float norm_bti = ((float)__ldcs(&inp_bt[i]) - mean_bt) * rstd_bt;
            float dnorm_i = (float)weight[i] * dout_i;
            // gradient contribution to bias
            atomicAdd(&dbias_shared[i], dout_i);
            // gradient contribution to weight
            atomicAdd(&dweight_shared[i], norm_bti * dout_i);
            // gradient contribution to input
            float dval = 0.0f;
            dval += dnorm_i; // term 1
            dval -= dnorm_mean; // term 2
            dval -= norm_bti * dnorm_norm_mean; // term 3
            dval *= rstd_bt; // final scale
            dinp_bt[i] = (floatX)((float)dinp_bt[i] + dval);
        }
    }

    // Accumulate into a FP32 scratchpad
    // BF16 atomics are potentially much slower... and this is more precise!
    __syncthreads();
    float* scratch_dbias = scratch;
    float* scratch_dweight = scratch + C;
    unsigned int* scratchFlag = (unsigned int*)(scratch + (2 * C));
    for (int i = threadIdx.x; i < C; i += blockDim.x) {
        atomicAdd(&scratch_dbias[i], dbias_shared[i]);
        atomicAdd(&scratch_dweight[i], dweight_shared[i]);
    }
    // Make this block's scratchpad updates visible device-wide before announcing completion through the
    // counter. Without the fence, the last block is not guaranteed to observe every contribution.
    __threadfence();
    __syncthreads();
    if (threadIdx.x == 0) {
        *tmp_flag = atomicAdd(scratchFlag, 1);
    }
    __syncthreads();
    if (*tmp_flag == gridDim.x-1) {
        // last block: every other block has published its contribution
        for (int i = threadIdx.x; i < C; i += blockDim.x) {
            // __ldcg reads the totals from L2, where the float atomics were performed, bypassing L1.
            // Add into the existing gradients instead of overwriting them.
            // todo - potentially do stochastic rounding here as well
            dbias[i] = (floatX)((float)dbias[i] + __ldcg(&scratch_dbias[i]));
            dweight[i] = (floatX)((float)dweight[i] + __ldcg(&scratch_dweight[i]));
        }
    }
}
// Layernorm backward pass. Each warp processes whole (b, t) rows, striding over all
// B * T rows by the total number of warps in the grid. For each row it:
//   1. warp-reduces the two means needed for the input gradient,
//   2. accumulates (+=) the input gradient into dinp,
//   3. accumulates this block's dbias / dweight contributions into shared memory
//      using a bank-conflict-friendly (reordered) index.
// Each block then atomically adds its shared partial sums into an FP32 scratchpad in
// global memory. The last block to finish (detected via an atomic ticket) converts the
// scratchpad back to the real channel order and adds it into dbias / dweight.
//
// Requirements:
// - C must be a multiple of warpSize * x128::size. The reordered indexing only covers
//   whole chunks of that size, so the kernel traps instead of silently skipping channels.
// - Dynamic shared memory: (2 * C + 1) * sizeof(float) bytes (dbias, dweight, one flag word).
// - scratch layout: [0, C) dbias accumulators, [C, 2C) dweight accumulators, then one
//   unsigned int completion flag. The kernel only adds into these and never resets them,
//   so all of them must be zeroed before every launch.
// - dinp, dbias and dweight are accumulated into, not overwritten.
// - NOTE(review): load128/store128 presumably need 16-byte aligned pointers - confirm at call sites.
__global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
layernorm_backward_kernel8(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,
                           const floatX* dout, const floatX* inp, const floatX* weight,
                           const floatX* mean, const floatX* rstd,
                           int B, int T, int C) {
    int C_per_iteration = warpSize * x128::size;
    // The shared-memory-friendly reordering below only handles full C_per_iteration chunks.
    // A remainder would be silently dropped (wrong gradients), so fail loudly like kernel9.
    if (C % C_per_iteration != 0) {
        if (threadIdx.x == 0 && blockIdx.x == 0) {
            printf("layernorm_backward_kernel8: C (%d) must be a multiple of %d\n", C, C_per_iteration);
        }
        __trap();
    }
    extern __shared__ float shared[]; // size = 2 * C + 1
    int warpId = threadIdx.x / warpSize; // warp index within a block
    int warpsInBlock = blockDim.x / warpSize; // number of warps in block
    int baseIdx = blockIdx.x * warpsInBlock + warpId;
    int warpThreadIdx = threadIdx.x % warpSize; // thread index within the warp
    int warpsInGrid = gridDim.x * warpsInBlock;
    int iterations_C = C / C_per_iteration;
    // the first half of shared memory is bias, second is weight
    float* dbias_shared = shared;
    float* dweight_shared = shared + C;
    // init shared memory to zero
    for (int i = threadIdx.x; i < C; i += blockDim.x) {
        dbias_shared[i] = 0.0f;
        dweight_shared[i] = 0.0f;
    }
    // one extra word after the two C-sized arrays, used to broadcast the ticket to the block
    unsigned int* tmp_flag = (unsigned int*)(shared + C * 2);
    __syncthreads();
    for (int idx = baseIdx; idx < B * T; idx += warpsInGrid) {
        // idx == b * T + t, so the row offset b * T * C + t * C equals idx * C.
        // Compute it in size_t so large B * T * C cannot overflow 32-bit int.
        size_t row_offset = (size_t)idx * C;
        const floatX* dout_bt = dout + row_offset;
        const floatX* inp_bt = inp + row_offset;
        floatX* dinp_bt = dinp + row_offset;
        const float mean_bt = (float)mean[idx];
        const float rstd_bt = (float)rstd[idx];
        // first: two reduce operations
        float dnorm_mean = 0.0f;
        float dnorm_norm_mean = 0.0f;
        for (int i = warpThreadIdx * x128::size; i < C; i += warpSize * x128::size) {
            x128 dout128_i = load128(dout_bt + i);
            x128 inp128_i = load128(inp_bt + i);
            x128 weight128_i = load128(weight + i);
            for (int k = 0; k < x128::size; k++) {
                float norm_bti = ((float)inp128_i[k] - mean_bt) * rstd_bt;
                float dnorm_i = (float)weight128_i[k] * (float)dout128_i[k];
                dnorm_mean += dnorm_i;
                dnorm_norm_mean += dnorm_i * norm_bti;
            }
        }
        dnorm_mean = warpReduceSum(dnorm_mean) / C;
        dnorm_norm_mean = warpReduceSum(dnorm_norm_mean) / C;
        // now iterate again and accumulate all the gradients
        // unfortunately we cannot use the same index for x128 arrays and shared memory
        // as atomics can only be 32-bit rather than 128-bit (at least pre-SM90/Hopper)
        // so this would result in an 8-way bank conflict, and kill performance
        // so instead, we use a shared memory friendly index, and reorder before the final write
        for (int i = 0; i < iterations_C; i++) {
            int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration);
            int shared_index = warpThreadIdx + (i * C_per_iteration);
            x128 dout128 = load128cs(dout_bt + global_index);
            x128 inp128 = load128cs(inp_bt + global_index);
            x128 dinp128 = load128(dinp_bt + global_index);
            x128 weight128 = load128(weight + global_index);
            for (int x = 0; x < x128::size; x++) {
                float dout_i = (float)dout128[x];
                float norm_bti = ((float)inp128[x] - mean_bt) * rstd_bt;
                float dnorm_i = (float)weight128[x] * dout_i;
                // gradient contribution to bias (using shared memory friendly index)
                atomicAdd(&dbias_shared[shared_index + x * warpSize], dout_i);
                // gradient contribution to weight (using shared memory friendly index)
                atomicAdd(&dweight_shared[shared_index + x * warpSize], norm_bti * dout_i);
                // gradient contribution to input
                float dval = 0.0f;
                dval += dnorm_i; // term 1
                dval -= dnorm_mean; // term 2
                dval -= norm_bti * dnorm_norm_mean; // term 3
                dval *= rstd_bt; // final scale
                dinp128[x] = (floatX)((float)dinp128[x] + dval);
            }
            // cache in L2 as this is read by the next kernel, but bypass L1 to minimise thrashing
            store128cg(dinp_bt + global_index, dinp128);
        }
    }
    // Accumulate into a FP32 scratchpad
    // BF16 atomics are potentially much slower... and this is more precise!
    // todo - could potentially avoid the extra copy if floatX is FP32, fairly negligible though
    __syncthreads();
    float* scratch_dbias = scratch;
    float* scratch_dweight = scratch + C;
    unsigned int* scratchFlag = (unsigned int*)(scratch + (2 * C));
    for (int i = threadIdx.x; i < C; i += blockDim.x) {
        // global atomics in the same "shared memory banking friendly" order
        atomicAdd(&scratch_dbias[i], dbias_shared[i]);
        atomicAdd(&scratch_dweight[i], dweight_shared[i]);
    }
    // Order this thread's scratch updates before the block's completion ticket below.
    // Without the fence, the last block could observe the ticket before other blocks'
    // partial sums are visible, and read incomplete totals.
    __threadfence();
    __syncthreads();
    if (threadIdx.x == 0) {
        // atomicInc returns the old value, so exactly one block (the last) sees gridDim.x - 1
        *tmp_flag = atomicInc(scratchFlag, gridDim.x);
    }
    __syncthreads();
    if (*tmp_flag == gridDim.x - 1) {
        for (int i = warpId; i < iterations_C; i += warpsInBlock) {
            // reorder from atomic/shared memory-friendly index to real global memory index
            // and convert from float/FP32 to floatX/BF16 for the final write
            int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration);
            int shared_index = warpThreadIdx + (i * C_per_iteration);
            x128 dbias128 = load128(dbias + global_index);
            x128 dweight128 = load128(dweight + global_index);
            for (int x = 0; x < x128::size; x++) {
                float s_db = scratch_dbias[shared_index + x * warpSize];
                float s_dw = scratch_dweight[shared_index + x * warpSize];
                dbias128[x] = (floatX)(s_db + (float)dbias128[x]);
                dweight128[x] = (floatX)(s_dw + (float)dweight128[x]);
            }
            store128(dbias + global_index, dbias128);
            store128(dweight + global_index, dweight128);
        }
    }
}
__global__ void layernorm_backward_kernel9(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,
const floatX* dout, const floatX* inp, const floatX* weight,
const floatX* mean, const floatX* rstd,
int B, int T, int C) {
if(C % (32 * x128::size) != 0) {
if(threadIdx.x == 0 && blockIdx.x == 0) {
printf("Number of channels is not a multiple of 32 * x128::size");
}
__trap(); // prefer to crash here than run into a deadlock later on
}
constexpr int WARP_SIZE = 32;
int BLOCK_SIZE = blockDim.x;
int warpsInBlock = BLOCK_SIZE / WARP_SIZE; //number of warps in block
extern __shared__ float shared[]; // size = 2 * C + 1
int warpId = threadIdx.x / WARP_SIZE; // warp index within a block
int baseIdx = blockIdx.x * warpsInBlock + warpId;
int warpThreadIdx = threadIdx.x % WARP_SIZE; // Thread index within the warp
int warpsInGrid = gridDim.x * warpsInBlock;
int C_per_iteration = WARP_SIZE * x128::size;
int iterations_C = ceil_div(C, C_per_iteration) + 2;
// the first half of shared memory is bias, second is weight
float* dbias_shared = shared;
float* dweight_shared = shared + C;
float* dbias_tmp_shared = shared + 2 * C;
float* dweight_tmp_shared = shared + 2 * C + BLOCK_SIZE;
// init shared memory to zero
for(int i = threadIdx.x; i < C; i+= BLOCK_SIZE){
dbias_shared[i] = 0.0f;
dweight_shared[i] = 0.0f;
}
unsigned int *tmp_flag = (unsigned int*)(shared + 2*C + 2*BLOCK_SIZE);
__syncthreads();
for (int idx = baseIdx; idx < B * T; idx += warpsInGrid) {
int b = idx / T;
int t = idx % T;
const floatX* dout_bt = dout + b * T * C + t * C;
const floatX* inp_bt = inp + b * T * C + t * C;
floatX* dinp_bt = dinp + b * T * C + t * C;
const float mean_bt = (float)mean[b * T + t];
const float rstd_bt = (float)rstd[b * T + t];
// first: two reduce operations
float dnorm_mean = 0.0f;
float dnorm_norm_mean = 0.0f;
for (int i = warpThreadIdx * x128::size; i < C; i += WARP_SIZE * x128::size) {
x128 dout128_i = load128(dout_bt + i);
x128 inp128_i = load128(inp_bt + i);
x128 weight128_i = load128(weight + i);
for (int k = 0; k < x128::size; k++) {
float norm_bti = ((float)inp128_i[k] - mean_bt) * rstd_bt;
float dnorm_i = (float)weight128_i[k] * (float)dout128_i[k];
dnorm_mean += dnorm_i;
dnorm_norm_mean += dnorm_i * norm_bti;
}
}
dnorm_mean = warpReduceSum(dnorm_mean) / C;
dnorm_norm_mean = warpReduceSum(dnorm_norm_mean) / C;
// now iterate again and accumulate all the gradients
// unfortunately we cannot use the same index for x128 arrays and shared memory
// as atomics can only be 32-bit rather than 128-bit (at least pre-SM90/Hopper)
// so this would result in an 8-way bank conflict, and kill performance
// so instead, we use a shared memory friendly index, and reorder before the final write
for (int i = 0; i < iterations_C; i++) {
int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration);
int shared_index = warpThreadIdx + (i * C_per_iteration);
if (global_index >= C) {
break;
}
x128 dout128 = load128cs(dout_bt + global_index);
x128 inp128 = load128cs(inp_bt + global_index);
x128 dinp128 = load128(dinp_bt + global_index);
x128 weight128 = load128(weight + global_index);
for (int x = 0; x < x128::size; x++) {
float dout_i = (float)dout128[x];
float norm_bti = ((float)inp128[x] - mean_bt) * rstd_bt;
float dnorm_i = (float)weight128[x] * dout_i;
// sum up the gradients for bias and weight across the entire block
// this is basically a reduction (but only inter-warp, not intra-warp)
// doing it this way allows us to avoid using atomics while using many warps
if (warpId != 0) {
dbias_tmp_shared[threadIdx.x] = dout_i;
dweight_tmp_shared[threadIdx.x] = norm_bti * dout_i;
}
__syncthreads();
if (warpId == 0) {
float dbias_tmp = dout_i;
float dweight_tmp = norm_bti * dout_i;
for (int j = 1; j < warpsInBlock; j++) {
dbias_tmp += dbias_tmp_shared[threadIdx.x + j * WARP_SIZE];
dweight_tmp += dweight_tmp_shared[threadIdx.x + j * WARP_SIZE];
}
// gradient contribution to bias (using shared memory friendly index)
dbias_shared[shared_index + x*WARP_SIZE] += dbias_tmp;
// gradient contribution to weight (using shared memory friendly index)
dweight_shared[shared_index + x*WARP_SIZE] += dweight_tmp;
}
__syncthreads();
// gradient contribution to input
float dval = 0.0f;
dval += dnorm_i; // term 1
dval -= dnorm_mean; // term 2
dval -= norm_bti * dnorm_norm_mean; // term 3
dval *= rstd_bt; // final scale
dinp128[x] = (floatX)((float)dinp128[x] + dval);
}
// cache in L2 as this is read by the next kernel, but bypass L1 to minimise thrashing
store128cg(dinp_bt + global_index, dinp128);
}
}
__syncthreads();
// Each block writes its partial sum to global memory
// The last block to finish becomes responsible for summing up all the partial sums
// This is done by atomically incrementing a flag (cleared to 0 before launching the kernel)
unsigned int* scratchFlag = (unsigned int*)(scratch);
// Increment scratch pointer by a full cacheline so that everything remains cacheline aligned
scratch += 32;
float* scratch_dbias = scratch;
float* scratch_dweight = scratch + C;
for(int i = threadIdx.x; i < C; i+= BLOCK_SIZE) {
// Write to global memory in the same "shared memory banking friendly" order
scratch_dbias[i + 2*C*blockIdx.x] = dbias_shared[i];
scratch_dweight[i + 2*C*blockIdx.x] = dweight_shared[i];
}
__syncthreads();