From cb791c4ef58d45d58e5af624b0ed41439ac7aeff Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 19 Apr 2024 22:55:30 +0000 Subject: [PATCH] new kernel that does a single pass over x on load, using a more clever variance formula. only very slightly faster on my A100 sadly --- dev/cuda/layernorm_forward.cu | 71 ++++++++++++++++++++++++++++++++++- 1 file changed, 69 insertions(+), 2 deletions(-) diff --git a/dev/cuda/layernorm_forward.cu b/dev/cuda/layernorm_forward.cu index f40b4e27..1f2209a3 100644 --- a/dev/cuda/layernorm_forward.cu +++ b/dev/cuda/layernorm_forward.cu @@ -12,6 +12,10 @@ version 2 parallelizes over all of B,T,C version 3 uses cooperative groups to parallelize over all of B,T,C ./layernorm_forward 3 + +version 4 uses a more clever way to estimate variance, var(x) = mean(x**2) - mean(x)**2 + (allowing us to do a single pass over x on load) +./layernorm_forward 4 */ #include @@ -172,14 +176,13 @@ __global__ void normalization_kernel(float* out, const float* inp, float* mean, out[idx] = o; } -// ---------------------------------------------------------------------------- - __global__ void layernorm_forward_kernel3(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd, const float* __restrict__ inp, const float* __restrict__ weight, const float* __restrict__ bias, int N, int C) { namespace cg = cooperative_groups; cg::thread_block block = cg::this_thread_block(); cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); + // meta_group_size is the number of warps in a block, and meta_group_rank is the warp index int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank(); if(idx >= N) { return; @@ -222,6 +225,56 @@ __global__ void layernorm_forward_kernel3(float* __restrict__ out, float* __rest } } +// same as kernel 3 but uses var(x) == mean(x**2) - mean(x)**2 +__global__ void layernorm_forward_kernel4(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd, + const float* __restrict__ 
inp, const float* __restrict__ weight, + const float* __restrict__ bias, int N, int C) { + namespace cg = cooperative_groups; + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); + int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank(); + if(idx >= N) { + return; + } + + // the row of input that this group of threads is responsible for + const float* x = inp + idx * C; + + // thread coarsening through the row, reduce the sum in series + float sum = 0.0; // stores sum(x) + float sum2 = 0.0; // stores sum(x**2) + for (int i = warp.thread_rank(); i < C; i += warp.size()) { + float xi = x[i]; + sum += xi; + sum2 += xi * xi; + } + // warp-level reduction at the end + sum = cg::reduce(warp, sum, cg::plus<float>{}); // sum(x) + sum2 = cg::reduce(warp, sum2, cg::plus<float>{}); // sum(x**2) + sum /= C; // mean(x) + sum2 /= C; // mean(x**2) + + // mean, var, rstd + float m = sum; + float var = sum2 - sum * sum; + float s = rsqrtf(var + 1e-5f); + + // store the mean, no need to cache it + if(warp.thread_rank() == 0 && mean != nullptr) { + __stcs(mean + idx, m); + } + // store the rstd, no need to cache it + if(warp.thread_rank() == 0 && rstd != nullptr) { + __stcs(rstd + idx, s); + } + // final normalization and scaling by weight/bias + float* o = out + idx * C; + for (int c = warp.thread_rank(); c < C; c += warp.size()) { + float n = s * (__ldcs(x+c) - m); + __stcs(o+c, n * weight[c] + bias[c]); + } +} + // ---------------------------------------------------------------------------- // kernel launcher @@ -263,6 +316,17 @@ void layernorm_forward3(float* out, float* mean, float* rstd, cudaCheck(cudaGetLastError()); } +void layernorm_forward4(float* out, float* mean, float* rstd, + const float* inp, const float* weight, const float* bias, + int B, int T, int C, + const int block_size) { + assert(block_size % 32 == 0); + const int N = B * T; + const int grid_size = ceil_div(N * 32, block_size); + 
layernorm_forward_kernel4<<<grid_size, block_size>>>(out, mean, rstd, inp, weight, bias, N, C); + cudaCheck(cudaGetLastError()); +} + // kernel version dispatch void layernorm_forward(int kernel_num, float* out, float* mean, float* rstd, @@ -279,6 +343,9 @@ void layernorm_forward(int kernel_num, case 3: layernorm_forward3(out, mean, rstd, inp, weight, bias, B, T, C, block_size); break; + case 4: + layernorm_forward4(out, mean, rstd, inp, weight, bias, B, T, C, block_size); + break; default: printf("Invalid kernel number\n"); exit(1);