
Commit

add one more kernel, allocating a block per row. bad idea if C is too low, as we have it right now
karpathy committed Apr 19, 2024
1 parent cb791c4 commit 49d41ae
Showing 1 changed file with 76 additions and 0 deletions.
76 changes: 76 additions & 0 deletions dev/cuda/layernorm_forward.cu
@@ -16,6 +16,9 @@ version 3 uses cooperative groups to parallelize over all of B,T,C
version 4 uses a more clever way to estimate variance, var(x) = mean(x**2) - mean(x)**2
(allowing us to do a single pass over x on load)
./layernorm_forward 4
version 5 allocates one block per row instead of one warp per row, same algorithm as 4 otherwise
./layernorm_forward 5
*/
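As an aside on the variance estimate used by versions 4 and 5: a single pass accumulating sum(x) and sum(x*x) is enough to recover both mean and variance. Below is a minimal single-row CPU sketch of the same computation; the function name and signature are made up for illustration (the file has its own CPU reference), and the 1e-5f epsilon is assumed to match the kernels.

#include <math.h>

// single pass over one row of length C: var(x) = mean(x**2) - mean(x)**2
void layernorm_row_sketch(float* out, float* mean, float* rstd,
                          const float* x, const float* weight, const float* bias, int C) {
    float sum = 0.0f, sum2 = 0.0f;
    for (int i = 0; i < C; i++) {
        float xi = x[i];
        sum += xi;           // accumulates sum(x)
        sum2 += xi * xi;     // accumulates sum(x**2)
    }
    float m = sum / C;                 // mean(x)
    float var = sum2 / C - m * m;      // mean(x**2) - mean(x)**2
    float s = 1.0f / sqrtf(var + 1e-5f);
    for (int i = 0; i < C; i++) {
        out[i] = s * (x[i] - m) * weight[i] + bias[i];
    }
    *mean = m;
    *rstd = s;
}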

#include <stdio.h>
@@ -275,6 +278,65 @@ __global__ void layernorm_forward_kernel4(float* __restrict__ out, float* __rest
}
}

// like 4, but in kernel 5 we have each block doing one row, not just a single warp
__global__ void layernorm_forward_kernel5(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd,
const float* __restrict__ inp, const float* __restrict__ weight,
const float* __restrict__ bias, int N, int C) {
namespace cg = cooperative_groups;
cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
__shared__ float shared_sum[32]; // block_size max is 1024 = 32 * 32 warps
__shared__ float shared_sum2[32]; // warps will be writing into shared memory after warp-reduce
int num_warps = blockDim.x / 32;
int warp_id = threadIdx.x / 32;
int lane_id = threadIdx.x % 32;
int idx = blockIdx.x; // simply one block per row
// the row of input that this group of threads is responsible for
const float* x = inp + idx * C;
// thread coarsening through the row, reduce the sum in series
float thread_sum = 0.0; // stores sum(x)
float thread_sum2 = 0.0; // stores sum(x**2)
// for (int i = C + threadIdx.x - blockDim.x; i >= 0; i -= blockDim.x) {
for (int i = threadIdx.x; i < C; i += blockDim.x) {
float xi = x[i];
thread_sum += xi;
thread_sum2 += xi * xi;
}
// warp-level reduction
float warp_sum = cg::reduce(warp, thread_sum, cg::plus<float>{}); // sum(x)
float warp_sum2 = cg::reduce(warp, thread_sum2, cg::plus<float>{}); // sum(x**2)
// store the warp-level reduction in shared memory (we could have lane_id == 0 guard but not needed)
shared_sum[warp_id] = warp_sum;
shared_sum2[warp_id] = warp_sum2;
__syncthreads();
// load results from shared memory to threads, pad with zeros for threads that are out of bounds
warp_sum = (lane_id < num_warps) ? shared_sum[lane_id] : 0.0f;
warp_sum2 = (lane_id < num_warps) ? shared_sum2[lane_id] : 0.0f;
// now reduce the warp-level reductions
float block_sum = cg::reduce(warp, warp_sum, cg::plus<float>{}); // sum(x)
float block_sum2 = cg::reduce(warp, warp_sum2, cg::plus<float>{}); // sum(x**2)
// mean, var, rstd
block_sum /= C; // mean(x)
block_sum2 /= C; // mean(x**2)
float m = block_sum;
float var = block_sum2 - m * m;
float s = rsqrtf(var + 1e-5f);
// store the mean, no need to cache it
if(threadIdx.x == 0 && mean != nullptr) {
__stcs(mean + idx, m);
}
// store the rstd, no need to cache it
if(threadIdx.x == 0 && rstd != nullptr) {
__stcs(rstd + idx, s);
}
// final normalization and scaling by weight/bias
// stride over the whole block here (not a single warp): every thread shares
// the same m and s, so each output element is written exactly once
float* o = out + idx * C;
for (int c = threadIdx.x; c < C; c += blockDim.x) {
float n = s * (__ldcs(x+c) - m);
__stcs(o+c, n * weight[c] + bias[c]);
}
}
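For intuition: cg::reduce over a full 32-thread tile with cg::plus<float> boils down to a warp-shuffle butterfly reduction. A rough stand-in is sketched below, for illustration only; the real cooperative-groups implementation also handles partial tiles and other operators.

// roughly what cg::reduce(warp, val, cg::plus<float>{}) does for a full warp:
// a butterfly reduction over lane-mask distances 16, 8, 4, 2, 1
__device__ float warp_sum_shfl(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_xor_sync(0xffffffffu, val, offset);
    }
    return val; // every lane ends up holding the sum over the warp
}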

// ----------------------------------------------------------------------------
// kernel launcher

@@ -327,6 +389,17 @@ void layernorm_forward4(float* out, float* mean, float* rstd,
cudaCheck(cudaGetLastError());
}

void layernorm_forward5(float* out, float* mean, float* rstd,
const float* inp, const float* weight, const float* bias,
int B, int T, int C,
const int block_size) {
assert(block_size % 32 == 0);
const int N = B * T;
const int grid_size = N;
layernorm_forward_kernel5<<<grid_size, block_size>>>(out, mean, rstd, inp, weight, bias, N, C);
cudaCheck(cudaGetLastError());
}
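A usage sketch of this launcher with illustrative shapes; the dimensions, block size, and device pointers below are assumptions for the example, not taken from this file.

// d_out, d_inp are assumed device buffers of B*T*C floats, d_mean/d_rstd of B*T,
// d_weight/d_bias of C, all allocated with cudaMalloc beforehand
int B = 8, T = 1024, C = 768;   // N = B*T = 8192 rows -> grid_size = 8192 blocks
int block_size = 256;           // 768 / 256 = 3 elements per thread in the coarsening loop
layernorm_forward5(d_out, d_mean, d_rstd, d_inp, d_weight, d_bias, B, T, C, block_size);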

// kernel version dispatch
void layernorm_forward(int kernel_num,
float* out, float* mean, float* rstd,
@@ -346,6 +419,9 @@ void layernorm_forward(int kernel_num,
case 4:
layernorm_forward4(out, mean, rstd, inp, weight, bias, B, T, C, block_size);
break;
case 5:
layernorm_forward5(out, mean, rstd, inp, weight, bias, B, T, C, block_size);
break;
default:
printf("Invalid kernel number\n");
exit(1);
