Skip to content

Commit

Permalink
new kernel that does a single pass over x on load, using a more cleve…
Browse files Browse the repository at this point in the history
…r variance formula. only very slightly faster on my A100 sadly
  • Loading branch information
karpathy committed Apr 19, 2024
1 parent 816254e commit cb791c4
Showing 1 changed file with 69 additions and 2 deletions.
71 changes: 69 additions & 2 deletions dev/cuda/layernorm_forward.cu
Expand Up @@ -12,6 +12,10 @@ version 2 parallelizes over all of B,T,C
version 3 uses cooperative groups to parallelize over all of B,T,C
./layernorm_forward 3
version 4 uses a more clever way to estimate variance, var(x) = mean(x**2) - mean(x)**2
(allowing us to do a single pass over x on load)
./layernorm_forward 4
*/

#include <stdio.h>
Expand Down Expand Up @@ -172,14 +176,13 @@ __global__ void normalization_kernel(float* out, const float* inp, float* mean,
out[idx] = o;
}

// ----------------------------------------------------------------------------

__global__ void layernorm_forward_kernel3(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd,
const float* __restrict__ inp, const float* __restrict__ weight,
const float* __restrict__ bias, int N, int C) {
namespace cg = cooperative_groups;
cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
// meta_group_size is the number of warps in a block, and meta_group_rank is the warp index
int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
if(idx >= N) {
return;
Expand Down Expand Up @@ -222,6 +225,56 @@ __global__ void layernorm_forward_kernel3(float* __restrict__ out, float* __rest
}
}

// same as kernel 3 but uses var(x) == mean(x**2) - mean(x)**2, which lets us
// accumulate both statistics in a single pass over x on load.
// Layout: one warp (32 threads) per row of the (N, C) input; the launcher must
// pass a block size that is a multiple of 32 (asserted in layernorm_forward4).
// mean and rstd may be nullptr if the caller does not need the cached statistics.
__global__ void layernorm_forward_kernel4(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd,
                                    const float* __restrict__ inp, const float* __restrict__ weight,
                                    const float* __restrict__ bias, int N, int C) {
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    // meta_group_size is the number of warps in a block, meta_group_rank is the warp index
    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
    if(idx >= N) {
        return;
    }

    // the row of input that this group of threads is responsible for
    const float* x = inp + idx * C;

    // thread coarsening through the row: each lane accumulates a strided
    // partial of both moments, so x is read exactly once
    float sum = 0.0f;  // stores sum(x); float literals to avoid double promotion
    float sum2 = 0.0f; // stores sum(x**2)
    for (int i = warp.thread_rank(); i < C; i += warp.size()) {
        float xi = x[i];
        sum += xi;
        sum2 += xi * xi;
    }
    // warp-level reduction at the end
    sum = cg::reduce(warp, sum, cg::plus<float>{});   // sum(x)
    sum2 = cg::reduce(warp, sum2, cg::plus<float>{}); // sum(x**2)
    sum /= C;  // mean(x)
    sum2 /= C; // mean(x**2)

    // mean, var, rstd. NOTE: E[x^2] - E[x]^2 suffers catastrophic cancellation
    // when |mean| >> std; clamp so rounding can never make var negative and
    // feed a NaN into rsqrtf
    float m = sum;
    float var = fmaxf(sum2 - sum * sum, 0.0f);
    float s = rsqrtf(var + 1e-5f);

    // lane 0 writes the per-row statistics; streaming stores (__stcs) since
    // these values are not re-read by this kernel
    if(warp.thread_rank() == 0) {
        if(mean != nullptr) { __stcs(mean + idx, m); }
        if(rstd != nullptr) { __stcs(rstd + idx, s); }
    }
    // final normalization and scaling by weight/bias; __ldcs/__stcs hint that
    // x and o are streamed through cache once
    float* o = out + idx * C;
    for (int c = warp.thread_rank(); c < C; c += warp.size()) {
        float n = s * (__ldcs(x+c) - m);
        __stcs(o+c, n * weight[c] + bias[c]);
    }
}

// ----------------------------------------------------------------------------
// kernel launcher

Expand Down Expand Up @@ -263,6 +316,17 @@ void layernorm_forward3(float* out, float* mean, float* rstd,
cudaCheck(cudaGetLastError());
}

// host-side launcher for layernorm_forward_kernel4: assigns one warp per row
// of the (B*T, C) input, so the grid supplies 32 threads for every row
void layernorm_forward4(float* out, float* mean, float* rstd,
                       const float* inp, const float* weight, const float* bias,
                       int B, int T, int C,
                       const int block_size) {
    // the kernel partitions each block into whole warps
    assert(block_size % 32 == 0);
    const int num_rows = B * T;
    // NOTE(review): num_rows * 32 is int arithmetic; presumably B*T stays well
    // below 2^26 so this cannot overflow — confirm for very large batches
    const int total_threads = num_rows * 32;
    const int grid_size = ceil_div(total_threads, block_size);
    layernorm_forward_kernel4<<<grid_size, block_size>>>(out, mean, rstd, inp, weight, bias, num_rows, C);
    cudaCheck(cudaGetLastError());
}

// kernel version dispatch
void layernorm_forward(int kernel_num,
float* out, float* mean, float* rstd,
Expand All @@ -279,6 +343,9 @@ void layernorm_forward(int kernel_num,
case 3:
layernorm_forward3(out, mean, rstd, inp, weight, bias, B, T, C, block_size);
break;
case 4:
layernorm_forward4(out, mean, rstd, inp, weight, bias, B, T, C, block_size);
break;
default:
printf("Invalid kernel number\n");
exit(1);
Expand Down

0 comments on commit cb791c4

Please sign in to comment.