From b37371fd04b45ea606d5f28082b9991c3bde0f56 Mon Sep 17 00:00:00 2001
From: Christopher
Date: Sun, 2 Jun 2024 00:52:21 +0000
Subject: [PATCH 1/2] Added packed layernorm_forward

---
 train_gpt2.cu | 41 +++++++++++++++++++++++++++++------------
 1 file changed, 29 insertions(+), 12 deletions(-)

diff --git a/train_gpt2.cu b/train_gpt2.cu
index bf6acf1fe..bc14dcd80 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -663,9 +663,13 @@ __global__ void layernorm_forward_kernel3(floatX* __restrict__ out, floatX* __re
     // mean
     float sum = 0.0f;
-    for (int i = lane_id; i < C; i += WARP_SIZE) {
-        sum += (float)x[i];
+    for (int i = lane_id * x128::size; i < C; i += WARP_SIZE * x128::size ) {
+        x128 inp_packed = load128(x + i);
+        for (int k = 0; k < x128::size; ++k) {
+            sum += (float)inp_packed[k];
+        }
     }
+
     sum = warpReduceSum(sum);
     float m = sum / C;
     if(lane_id == 0 && mean != nullptr) {
@@ -674,27 +678,40 @@ __global__ void layernorm_forward_kernel3(floatX* __restrict__ out, floatX* __re
     // rstd
     sum = 0.0f;
-    for (int i = lane_id; i < C; i += WARP_SIZE) {
-        float diff = (float)x[i] - m;
-        sum += diff * diff;
+    for (int i = lane_id * x128::size; i < C; i += WARP_SIZE * x128::size) {
+        x128 inp_packed = load128(x + i);
+        for (int k = 0; k < x128::size; ++k) {
+            float diff = (float)inp_packed[k] - m;
+            sum += diff * diff;
+        }
     }
+
     sum = warpReduceSum(sum);
     float s = rsqrtf(sum / C + 1e-5f);
     if(lane_id == 0 && rstd != nullptr) {
         __stcs(rstd + idx, (floatX)s);
     }
 
-    // final normalization and scaling by weight/bias
+
     floatX* o = out + idx * C;
-    for (int c = lane_id; c < C; c += WARP_SIZE) {
-        // load and store using the .cs "streaming" hint to the compiler,
-        // indicating that this data will not be reused soon, and can be streamed through the caches
-        // this allows the threads to get more cache-hits for the (shared) weight and bias parameters
-        float n = s * ((float)__ldcs(x+c) - m);
-        __stcs(o+c, (floatX)(n * (float)weight[c] + (float)bias[c]));
+    for (int c = lane_id * x128::size; c < C; c += WARP_SIZE * x128::size) {
+        // Load data into packed format
+        x128 inp_packed = load128cs(x + c);
+        x128 weight_packed = load128cs(weight + c);
+        x128 bias_packed = load128cs(bias + c);
+        x128 out_packed;
+
+        for (int k = 0; k < x128::size; ++k) {
+            float n = s * ((float)inp_packed[k] - m);
+            out_packed[k] = (floatX)(n * (float)weight_packed[k] + (float)bias_packed[k]);
+        }
+
+        // Store packed data back
+        store128(o + c, out_packed);
     }
 }
 
+
 __global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,
                                                const floatX* inp1, const floatX* inp2,
                                                const floatX* weight, const floatX* bias,

From cc8a18d20e1e1b9a6357768b3e142f8e57a7bb3f Mon Sep 17 00:00:00 2001
From: Christopher
Date: Mon, 3 Jun 2024 03:44:46 +0000
Subject: [PATCH 2/2] Added the layernorm forward dev kernels for the packing changes

---
 dev/cuda/layernorm_forward.cu | 513 +++++++++++++++++++++++++++-------
 1 file changed, 418 insertions(+), 95 deletions(-)

diff --git a/dev/cuda/layernorm_forward.cu b/dev/cuda/layernorm_forward.cu
index 3e948289a..846914b16 100644
--- a/dev/cuda/layernorm_forward.cu
+++ b/dev/cuda/layernorm_forward.cu
@@ -2,7 +2,7 @@
 Kernels for layernorm forward pass.
 
 Compile example:
-nvcc -O3 --use_fast_math -lcublas -lcublasLt layernorm_forward.cu -o layernorm_forward
+nvcc -O3 --use_fast_math layernorm_forward.cu -o layernorm_forward
 
 version 1 is naive port from CPU code to kernel: parallelizes over B,T, loops over C
 ./layernorm_forward 1
 
 version 2 parallelizes over all of B,T,C
 ./layernorm_forward 2
 
 version 3 uses cooperative groups to do warp reduction
 ./layernorm_forward 3
@@ -19,6 +19,19 @@ version 4 uses a more clever way to estimate variance, var(x) = mean(x**2) - mea
 verstion 5 allocates blocks per row instead of warps per row, same alg as 4 otherwise
 ./layernorm_forward 5
+
+version 6 removes cooperative groups
+./layernorm_forward 6
+
+version 7 adds data type packing to the output
+./layernorm_forward 7
+
+version 8 adds packing to the mean and rstd
+./layernorm_forward 8
+
+version 9 shows the performance impact of splitting the mean and rstd
+./layernorm_forward 9
+
 */
 
 #include <stdio.h>
@@ -29,6 +42,8 @@ verstion 5 allocates blocks per row instead of warps per row, same alg as 4 othe
 #include <cassert>
 #include "common.h"
 
+#define WARP_SIZE 32U
+
 // ----------------------------------------------------------------------------
 // CPU code reference
 
@@ -74,52 +89,52 @@ void layernorm_forward_cpu(float* out, float* mean, float* rstd,
 // GPU kernels
 // naive drag and drop implementation into kernel, parallelize over B,T, loop over C
-__global__ void layernorm_forward_kernel1(float* out, float* mean, float* rstd,
-                                          const float* inp, const float* weight, const float* bias,
+__global__ void layernorm_forward_kernel1(floatX* out, floatX* mean, floatX* rstd,
+                                          const floatX* inp, const floatX* weight, const floatX* bias,
                                           int N, int C) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     float eps = 1e-5f;
 
     if (idx < N) {
         // seek to the input position inp[idx,:]
-        const float* x = inp + idx * C;
+        const floatX* x = inp + idx * C;
         // calculate the mean
         float m = 0.0f;
         for (int i = 0; i < C; i++) {
-            m += x[i];
+            m += (float)x[i];
         }
         m = m / C;
         // calculate the variance (without any bias correction)
         float v = 0.0f;
         for (int i = 0; i < C; i++) {
-            float xshift = x[i] - m;
+            float xshift = (float)x[i] - m;
            v += xshift * xshift;
         }
         v = v / C;
         // calculate the rstd
         float s = 1.0f / sqrtf(v + eps);
         // seek to the output position in out[idx,:]
-        float* out_idx = out + idx * C;
+        floatX* out_idx = (floatX*)out + idx * C;
         for (int i = 0; i < C; i++) {
-            float n = (s * (x[i] - m)); // normalized output
-            float o = n * weight[i] + bias[i]; // scale and shift it
-            out_idx[i] = o; // write
+            float n = (s * ((float)x[i] - m)); // normalized output
+            float o = n * (float)weight[i] + (float)bias[i]; // scale and shift it
+            out_idx[i] = (floatX)o; // write
         }
         // cache the mean and rstd for the backward pass later
-        mean[idx] = m;
-        rstd[idx] = s;
+        mean[idx] = (floatX)m;
+        rstd[idx] = (floatX)s;
     }
 }
 
-__global__ void mean_kernel(float* mean, const float* inp, int N, int C, int block_size) {
+__global__ void mean_kernel(floatX* mean, const floatX* inp, int N, int C, int block_size) {
     extern __shared__ float shared[];
     int idx = blockIdx.x; // range [0, B*T)
     int tid = threadIdx.x; // range [0, block_size)
-    const float* x = inp + idx * C;
+    const floatX* x = inp + idx * C;
     // thread coarsening
     float sum = 0.0f;
     for (int i = tid; i < C; i += block_size) {
-        sum += x[i];
+        sum += (float)x[i];
     }
     shared[tid] = sum;
     __syncthreads();
@@ -132,20 +147,20 @@ __global__ void mean_kernel(float* mean, const float* inp, int N, int C, int blo
     }
     // write the final result (at thread 0) to global memory
     if (tid == 0) {
-        mean[idx] = shared[0] / C;
+        mean[idx] = (floatX)(shared[0] / C);
     }
 }
 
-__global__ void rstd_kernel(float* rstd, const float* inp, const float* mean, int N, int C, int block_size) {
+__global__ void rstd_kernel(floatX* rstd, const floatX* inp, const floatX* mean, int N, int C, int block_size) {
     extern __shared__ float shared[];
     int idx = blockIdx.x; // range [0, B*T)
     int tid = threadIdx.x; // range [0, block_size)
-    const float* x = inp + idx * C;
-    float m = mean[idx];
+    const floatX* x = inp + idx * C;
+    float m = (float)mean[idx];
     // thread coarsening
     float sum = 0.0f;
     for (int i = tid; i < C; i += block_size) {
-        float diff = x[i] - m;
+        float diff = (float)x[i] - m;
         sum += diff * diff;
     }
     shared[tid] = sum;
     __syncthreads();
@@ -159,29 +174,29 @@ __global__ void rstd_kernel(float* rstd, const float* inp, const float* mean, in
     }
     // write the final result (at thread 0) to global memory
     if (tid == 0) {
-        rstd[idx] = 1.0f / sqrtf(shared[0] / C + 1e-5f);
+        rstd[idx] = (floatX)(1.0f / sqrtf(shared[0] / C + 1e-5f));
     }
 }
 
-__global__ void normalization_kernel(float* out, const float* inp, float* mean, float* rstd,
-                                     const float* weight, const float* bias, int B, int T, int C) {
+__global__ void normalization_kernel(floatX* out, const floatX* inp, floatX* mean, floatX* rstd,
+                                     const floatX* weight, const floatX* bias, int B, int T, int C) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     int bt = idx / C;
     int c = idx % C;
-    float m = mean[bt];
-    float s = rstd[bt];
-    float xi = inp[idx];
+    float m = (float)mean[bt];
+    float s = (float)rstd[bt];
+    float xi = (float)inp[idx];
     float n = s * (xi - m);
-    float o = n * weight[c] + bias[c];
-    out[idx] = o;
+    float o = n * (float)weight[c] + (float)bias[c];
+    out[idx] = (floatX)o;
 }
 
-__global__ void layernorm_forward_kernel3(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd,
-                                          const float* __restrict__ inp, const float* __restrict__ weight,
-                                          const float* __restrict__ bias, int N, int C) {
+__global__ void layernorm_forward_kernel3(floatX* __restrict__ out, floatX* __restrict__ mean, floatX* __restrict__ rstd,
+                                          const floatX* __restrict__ inp, const floatX* __restrict__ weight,
+                                          const floatX* __restrict__ bias, int N, int C) {
     namespace cg = cooperative_groups;
     cg::thread_block block = cg::this_thread_block();
     cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
@@ -192,46 +207,46 @@ __global__ void layernorm_forward_kernel3(float* __rest
     }
 
     // the row of input that this group of threads is responsible for
-    const float* x = inp + idx * C;
+    const floatX* x = inp + idx * C;
 
     // mean
     float sum = 0.0f;
     for (int i = warp.thread_rank(); i < C; i += warp.size()) {
-        sum += x[i];
+        sum += (float)x[i];
     }
     sum = cg::reduce(warp, sum, cg::plus<float>{});
     float m = sum / C;
     if(warp.thread_rank() == 0 && mean != nullptr) {
-        __stcs(mean + idx, m);
+        __stcs(mean + idx, (floatX)m);
     }
 
     // rstd
     sum = 0.0f;
     for (int i = warp.thread_rank(); i < C; i += warp.size()) {
-        float diff = x[i] - m;
+        float diff = (float)x[i] - m;
         sum += diff * diff;
     }
     sum = cg::reduce(warp, sum, cg::plus<float>{});
     float s = rsqrtf(sum / C + 1e-5f);
     if(warp.thread_rank() == 0 && rstd != nullptr) {
-        __stcs(rstd + idx, s);
+        __stcs(rstd + idx, (floatX)s);
     }
 
     // final normalization and scaling by weight/bias
-    float* o = out + idx * C;
+    floatX* o = out + idx * C;
     for (int c = warp.thread_rank(); c < C; c += warp.size()) {
         // load and store using the .cs "streaming" hint to the compiler,
         // indicating that this data will not be reused soon, and can be streamed through the caches
         // this allows the threads to get more cache-hits for the (shared) weight and bias parameters
-        float n = s *
(__ldcs(x+c) - m); - __stcs(o+c, n * weight[c] + bias[c]); + float n = s * ((float)__ldcs(x+c) - m); + __stcs(o+c, n * (float)weight[c] + (float)bias[c]); } } // same as kernel 3 but uses var(x) == mean(x**2) - mean(x)**2 -__global__ void layernorm_forward_kernel4(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd, - const float* __restrict__ inp, const float* __restrict__ weight, - const float* __restrict__ bias, int N, int C) { +__global__ void layernorm_forward_kernel4(floatX* __restrict__ out, floatX* __restrict__ mean, floatX* __restrict__ rstd, + const floatX* __restrict__ inp, const floatX* __restrict__ weight, + const floatX* __restrict__ bias, int N, int C) { namespace cg = cooperative_groups; cg::thread_block block = cg::this_thread_block(); cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); @@ -241,13 +256,13 @@ __global__ void layernorm_forward_kernel4(float* __restrict__ out, float* __rest } // the row of input that this group of threads is responsible for - const float* x = inp + idx * C; + const floatX* x = inp + idx * C; // thread coarsening through the row, reduce the sum in series float sum = 0.0; // stores sum(x) float sum2 = 0.0; // stores sum(x**2) for (int i = warp.thread_rank(); i < C; i += warp.size()) { - float xi = x[i]; + float xi = (float)x[i]; sum += xi; sum2 += xi * xi; } @@ -264,24 +279,24 @@ __global__ void layernorm_forward_kernel4(float* __restrict__ out, float* __rest // store the mean, no need to cache it if(warp.thread_rank() == 0 && mean != nullptr) { - __stcs(mean + idx, m); + __stcs(mean + idx, (floatX)m); } // store the rstd, no need to cache it if(warp.thread_rank() == 0 && rstd != nullptr) { - __stcs(rstd + idx, s); + __stcs(rstd + idx, (floatX)s); } // final normalization and scaling by weight/bias - float* o = out + idx * C; + floatX* o = out + idx * C; for (int c = warp.thread_rank(); c < C; c += warp.size()) { - float n = s * (__ldcs(x+c) - m); - __stcs(o+c, n * weight[c] + bias[c]); + float n = s * ((float)__ldcs(x+c) - m); + __stcs(o+c, n * (float)weight[c] + (float)bias[c]); } } // like 4, but in kernel 5 we have each block doing one row, not just a single warp -__global__ void layernorm_forward_kernel5(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd, - const float* __restrict__ inp, const float* __restrict__ weight, - const float* __restrict__ bias, int N, int C) { +__global__ void layernorm_forward_kernel5(floatX* __restrict__ out, floatX* __restrict__ mean, floatX* __restrict__ rstd, + const floatX* __restrict__ inp, const floatX* __restrict__ weight, + const floatX* __restrict__ bias, int N, int C) { namespace cg = cooperative_groups; cg::thread_block block = cg::this_thread_block(); cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); @@ -292,13 +307,13 @@ __global__ void layernorm_forward_kernel5(float* __restrict__ out, float* __rest int lane_id = threadIdx.x % 32; int idx = blockIdx.x; // simpoy one block per row // the row of input that this group of threads is responsible for - const float* x = inp + idx * C; + const floatX* x = inp + idx * C; // thread coarsening through the row, reduce the sum in series float thread_sum = 0.0; // stores sum(x) float thread_sum2 = 0.0; // stores sum(x**2) // for (int i = C + threadIdx.x - blockDim.x; i >= 0; i -= blockDim.x) { for (int i = threadIdx.x; i < C; i += blockDim.x) { - float xi = x[i]; + float xi = (float)x[i]; thread_sum += xi; thread_sum2 += xi * xi; } @@ -323,25 +338,280 @@ __global__ void 
layernorm_forward_kernel5(float* __restrict__ out, float* __rest float s = rsqrtf(var + 1e-5f); // store the mean, no need to cache it if(threadIdx.x == 0 && mean != nullptr) { - __stcs(mean + idx, m); + __stcs(mean + idx, (floatX)m); } // store the rstd, no need to cache it if(threadIdx.x == 0 && rstd != nullptr) { - __stcs(rstd + idx, s); + __stcs(rstd + idx, (floatX)s); } // final normalization and scaling by weight/bias - float* o = out + idx * C; + floatX* o = out + idx * C; for (int i = threadIdx.x; i < C; i += blockDim.x) { - float n = s * (__ldcs(x+i) - m); - __stcs(o+i, n * weight[i] + bias[i]); + float n = s * ((float)__ldcs(x+i) - m); + __stcs(o+i, n * (float)weight[i] + (float)bias[i]); + } +} + +__global__ void layernorm_forward_kernel6(floatX* __restrict__ out, floatX* __restrict__ mean, floatX* __restrict__ rstd, + const floatX* __restrict__ inp, const floatX* __restrict__ weight, + const floatX* __restrict__ bias, int N, int C) { + int lane_id = threadIdx.x % WARP_SIZE; + int warp_id = threadIdx.x / WARP_SIZE; + int num_warps = blockDim.x / WARP_SIZE; + + int idx = blockIdx.x * num_warps + warp_id; + if(idx >= N) { return; } // guard + + // the row of input that this group of threads is responsible for + const floatX* x = inp + idx * C; + + // mean + float sum = 0.0f; + for (int i = lane_id; i < C; i += WARP_SIZE) { + sum += (float)x[i]; + } + sum = warpReduceSum(sum); + float m = sum / C; + if(lane_id == 0 && mean != nullptr) { + __stcs(mean + idx, (floatX)m); + } + + // rstd + sum = 0.0f; + for (int i = lane_id; i < C; i += WARP_SIZE) { + float diff = (float)x[i] - m; + sum += diff * diff; + } + sum = warpReduceSum(sum); + float s = rsqrtf(sum / C + 1e-5f); + if(lane_id == 0 && rstd != nullptr) { + __stcs(rstd + idx, (floatX)s); + } + + // final normalization and scaling by weight/bias + floatX* o = out + idx * C; + for (int c = lane_id; c < C; c += WARP_SIZE) { + // load and store using the .cs "streaming" hint to the compiler, + // indicating that this data will not be reused soon, and can be streamed through the caches + // this allows the threads to get more cache-hits for the (shared) weight and bias parameters + float n = s * ((float)__ldcs(x+c) - m); + __stcs(o+c, (floatX)(n * (float)weight[c] + (float)bias[c])); + } +} + + +__global__ void layernorm_forward_kernel7(floatX* __restrict__ out, floatX* __restrict__ mean, floatX* __restrict__ rstd, + const floatX* __restrict__ inp, const floatX* __restrict__ weight, + const floatX* __restrict__ bias, int N, int C) { + int lane_id = threadIdx.x % WARP_SIZE; + int warp_id = threadIdx.x / WARP_SIZE; + int num_warps = blockDim.x / WARP_SIZE; + + int idx = blockIdx.x * num_warps + warp_id; + if(idx >= N) { return; } // guard + + // the row of input that this group of threads is responsible for + const floatX* x = inp + idx * C; + + // mean + float sum = 0.0f; + for (int i = lane_id; i < C; i += WARP_SIZE) { + sum += (float)x[i]; + } + sum = warpReduceSum(sum); + float m = sum / C; + if(lane_id == 0 && mean != nullptr) { + __stcs(mean + idx, (floatX)m); + } + + // rstd + sum = 0.0f; + for (int i = lane_id; i < C; i += WARP_SIZE) { + float diff = (float)x[i] - m; + sum += diff * diff; + } + sum = warpReduceSum(sum); + float s = rsqrtf(sum / C + 1e-5f); + if(lane_id == 0 && rstd != nullptr) { + __stcs(rstd + idx, (floatX)s); + } + + floatX* o = out + idx * C; + int c = lane_id * x128::size; + for (; c < C; c += WARP_SIZE * x128::size) { + // Load data into packed format + x128 inp_packed = load128(x + c); + x128 weight_packed 
= load128(weight + c); + x128 bias_packed = load128(bias + c); + + x128 out_packed; + + for (int k = 0; k < x128::size; ++k) { + float n = s * ((float)inp_packed[k] - m); + out_packed[k] = (floatX)(n * (float)weight_packed[k] + (float)bias_packed[k]); + } + + // Store packed data back + store128(o + c, out_packed); + } +} + +__global__ void layernorm_forward_kernel8(floatX* __restrict__ out, floatX* __restrict__ mean, floatX* __restrict__ rstd, + const floatX* __restrict__ inp, const floatX* __restrict__ weight, + const floatX* __restrict__ bias, int N, int C) { + int lane_id = threadIdx.x % WARP_SIZE; + int warp_id = threadIdx.x / WARP_SIZE; + int num_warps = blockDim.x / WARP_SIZE; + + int idx = blockIdx.x * num_warps + warp_id; + if(idx >= N) { return; } // guard + + // the row of input that this group of threads is responsible for + const floatX* x = inp + idx * C; + + // mean + float sum = 0.0f; + int i = lane_id * x128::size; + for (; i < C; i += WARP_SIZE * x128::size ) { + x128 inp_packed = load128(x + i); + for (int k = 0; k < x128::size; ++k) { + sum += (float)inp_packed[k]; + } + } + + + sum = warpReduceSum(sum); + float m = sum / C; + if(lane_id == 0 && mean != nullptr) { + __stcs(mean + idx, (floatX)m); + } + + // rstd + sum = 0.0f; + i = lane_id * x128::size; + for (; i < C; i += WARP_SIZE * x128::size) { + x128 inp_packed = load128(x + i); + for (int k = 0; k < x128::size; ++k) { + float diff = (float)inp_packed[k] - m; + sum += diff * diff; + } + } + + sum = warpReduceSum(sum); + float s = rsqrtf(sum / C + 1e-5f); + if(lane_id == 0 && rstd != nullptr) { + __stcs(rstd + idx, (floatX)s); + } + + + floatX* o = out + idx * C; + int c = lane_id * x128::size; + for (; c < C; c += WARP_SIZE * x128::size) { + // Load data into packed format + x128 inp_packed = load128(x + c); + x128 weight_packed = load128(weight + c); + x128 bias_packed = load128(bias + c); + + x128 out_packed; + + for (int k = 0; k < x128::size; ++k) { + float n = s * ((float)inp_packed[k] - m); + out_packed[k] = (floatX)(n * (float)weight_packed[k] + (float)bias_packed[k]); + } + + // Store packed data back + store128(o + c, out_packed); + } +} + + +__global__ void layernorm_forward_kernel9(floatX* __restrict__ out, floatX* __restrict__ mean, floatX* __restrict__ rstd, + const floatX* __restrict__ inp, const floatX* __restrict__ weight, + const floatX* __restrict__ bias, int N, int C) { + int lane_id = threadIdx.x % WARP_SIZE; + int warp_id = threadIdx.x / WARP_SIZE; + int num_warps = blockDim.x / WARP_SIZE; + + int idx = blockIdx.x * num_warps + warp_id; + if(idx >= N) { return; } // guard + + // the row of input that this group of threads is responsible for + const floatX* x = inp + idx * C; + + float m = __ldcs(mean + idx); + float s = __ldcs(rstd + idx); + + floatX* o = out + idx * C; + int c = lane_id * x128::size; + for (; c < C; c += WARP_SIZE * x128::size) { + // Load data into packed format + x128 inp_packed = load128cs(x + c); + x128 weight_packed = load128cs(weight + c); + x128 bias_packed = load128cs(bias + c); + x128 out_packed; + + #pragma unroll + for (int k = 0; k < x128::size; ++k) { + float n = s * ((float)inp_packed[k] - m); + out_packed[k] = (floatX)(n * (float)weight_packed[k] + (float)bias_packed[k]); + } + + // Store packed data back + store128(o + c, out_packed); + } +} + + +__global__ void mean_rstd(floatX* __restrict__ out, floatX* __restrict__ mean, floatX* __restrict__ rstd, + const floatX* __restrict__ inp, const floatX* __restrict__ weight, + const floatX* __restrict__ bias, int N, 
int C) { + int lane_id = threadIdx.x % WARP_SIZE; + int warp_id = threadIdx.x / WARP_SIZE; + int num_warps = blockDim.x / WARP_SIZE; + + int idx = blockIdx.x * num_warps + warp_id; + if(idx >= N) { return; } // guard + + // the row of input that this group of threads is responsible for + const floatX* x = inp + idx * C; + + // mean + float sum = 0.0f; + for (int i = lane_id * x128::size; i < C; i += WARP_SIZE * x128::size ) { + x128 inp_packed = load128(x + i); + for (int k = 0; k < x128::size; ++k) { + sum += (float)inp_packed[k]; + } + } + + sum = warpReduceSum(sum); + float m = sum / C; + if(lane_id == 0 && mean != nullptr) { + mean[idx] = (floatX)m; + } + + // rstd + sum = 0.0f; + for (int i = lane_id * x128::size; i < C; i += WARP_SIZE * x128::size) { + x128 inp_packed = load128(x + i); + for (int k = 0; k < x128::size; ++k) { + float diff = (float)inp_packed[k] - m; + sum += diff * diff; + } + } + + sum = warpReduceSum(sum); + if(lane_id == 0 && rstd != nullptr) { + float s = rsqrtf(sum / C + 1e-5f); + rstd[idx] = (floatX)s; } } // ---------------------------------------------------------------------------- // kernel launcher -void layernorm_forward1(float* out, float* mean, float* rstd, - const float* inp, const float* weight, const float* bias, +void layernorm_forward1(floatX* out, floatX* mean, floatX* rstd, + const floatX* inp, const floatX* weight, const floatX* bias, int B, int T, int C, const int block_size) { const int N = B * T; @@ -350,8 +620,8 @@ void layernorm_forward1(float* out, float* mean, float* rstd, cudaCheck(cudaGetLastError()); } -void layernorm_forward2(float* out, float* mean, float* rstd, - const float* inp, const float* weight, const float* bias, +void layernorm_forward2(floatX* out, floatX* mean, floatX* rstd, + const floatX* inp, const floatX* weight, const floatX* bias, int B, int T, int C, const int block_size) { int N = B * T; @@ -367,8 +637,8 @@ void layernorm_forward2(float* out, float* mean, float* rstd, cudaCheck(cudaGetLastError()); } -void layernorm_forward3(float* out, float* mean, float* rstd, - const float* inp, const float* weight, const float* bias, +void layernorm_forward3(floatX* out, floatX* mean, floatX* rstd, + const floatX* inp, const floatX* weight, const floatX* bias, int B, int T, int C, const int block_size) { assert(block_size % 32 == 0); @@ -378,8 +648,8 @@ void layernorm_forward3(float* out, float* mean, float* rstd, cudaCheck(cudaGetLastError()); } -void layernorm_forward4(float* out, float* mean, float* rstd, - const float* inp, const float* weight, const float* bias, +void layernorm_forward4(floatX* out, floatX* mean, floatX* rstd, + const floatX* inp, const floatX* weight, const floatX* bias, int B, int T, int C, const int block_size) { assert(block_size % 32 == 0); @@ -389,8 +659,8 @@ void layernorm_forward4(float* out, float* mean, float* rstd, cudaCheck(cudaGetLastError()); } -void layernorm_forward5(float* out, float* mean, float* rstd, - const float* inp, const float* weight, const float* bias, +void layernorm_forward5(floatX* out, floatX* mean, floatX* rstd, + const floatX* inp, const floatX* weight, const floatX* bias, int B, int T, int C, const int block_size) { assert(block_size % 32 == 0); @@ -400,10 +670,56 @@ void layernorm_forward5(float* out, float* mean, float* rstd, cudaCheck(cudaGetLastError()); } + +void layernorm_forward6(floatX* out, floatX* mean, floatX* rstd, + const floatX* inp, const floatX* weight, const floatX* bias, + int B, int T, int C, + const int block_size) { + assert(block_size % 32 == 0); + 
    const int N = B * T;
+    const int grid_size = ceil_div((int)(N*WARP_SIZE), block_size);
+    layernorm_forward_kernel6<<<grid_size, block_size>>>(out, mean, rstd, inp, weight, bias, N, C);
+    cudaCheck(cudaGetLastError());
+}
+
+void layernorm_forward7(floatX* out, floatX* mean, floatX* rstd,
+                        const floatX* inp, const floatX* weight, const floatX* bias,
+                        int B, int T, int C,
+                        const int block_size) {
+    assert(block_size % 32 == 0);
+    const int N = B * T;
+    const int grid_size = N;
+    layernorm_forward_kernel7<<<grid_size, block_size>>>(out, mean, rstd, inp, weight, bias, N, C);
+    cudaCheck(cudaGetLastError());
+}
+
+void layernorm_forward8(floatX* out, floatX* mean, floatX* rstd,
+                        const floatX* inp, const floatX* weight, const floatX* bias,
+                        int B, int T, int C,
+                        const int block_size) {
+    assert(block_size % 32 == 0);
+    const int N = B * T;
+    const int grid_size = ceil_div((int)(N*WARP_SIZE), block_size);
+    layernorm_forward_kernel8<<<grid_size, block_size>>>(out, mean, rstd, inp, weight, bias, N, C);
+    cudaCheck(cudaGetLastError());
+}
+
+void layernorm_forward9(floatX* out, floatX* mean, floatX* rstd,
+                        const floatX* inp, const floatX* weight, const floatX* bias,
+                        int B, int T, int C,
+                        const int block_size) {
+    assert(block_size % 32 == 0);
+    const int N = B * T;
+    const int grid_size = ceil_div((int)(N*WARP_SIZE), block_size);
+    mean_rstd<<<grid_size, block_size>>>(out, mean, rstd, inp, weight, bias, N, C);
+    layernorm_forward_kernel9<<<grid_size, block_size>>>(out, mean, rstd, inp, weight, bias, N, C);
+    cudaCheck(cudaGetLastError());
+}
+
 // kernel version dispatch
 void layernorm_forward(int kernel_num,
-                       float* out, float* mean, float* rstd,
-                       const float* inp, const float* weight, const float* bias,
+                       floatX* out, floatX* mean, floatX* rstd,
+                       const floatX* inp, const floatX* weight, const floatX* bias,
                        int B, int T, int C,
                        const int block_size) {
     switch (kernel_num) {
@@ -422,6 +738,18 @@ void layernorm_forward(int kernel_num,
         case 5:
             layernorm_forward5(out, mean, rstd, inp, weight, bias, B, T, C, block_size);
             break;
+        case 6:
+            layernorm_forward6(out, mean, rstd, inp, weight, bias, B, T, C, block_size);
+            break;
+        case 7:
+            layernorm_forward7(out, mean, rstd, inp, weight, bias, B, T, C, block_size);
+            break;
+        case 8:
+            layernorm_forward8(out, mean, rstd, inp, weight, bias, B, T, C, block_size);
+            break;
+        case 9:
+            layernorm_forward9(out, mean, rstd, inp, weight, bias, B, T, C, block_size);
+            break;
         default:
             printf("Invalid kernel number\n");
             exit(1);
@@ -431,39 +759,37 @@ void layernorm_forward(int kernel_num,
 // ----------------------------------------------------------------------------
 
 int main(int argc, char **argv) {
-    srand(0);
+    setup_main();
 
     int B = 8;
     int T = 1024;
     int C = 768;
 
-    int deviceIdx = 0;
-    cudaCheck(cudaSetDevice(deviceIdx));
-
     // create host memory of random numbers
     float* out = (float*)malloc(B * T * C * sizeof(float));
     float* mean = (float*)malloc(B * T * sizeof(float));
     float* rstd = (float*)malloc(B * T * sizeof(float));
     float* inp = make_random_float(B * T * C);
     float* weight = make_random_float(C);
-    float* bias = make_random_float(C);
-
+    float* bias = make_random_float(C);
+
     // move to GPU
-    float* d_out;
-    float* d_mean;
-    float* d_rstd;
-    float* d_inp;
-    float* d_weight;
-    float* d_bias;
-    cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(float)));
-    cudaCheck(cudaMalloc(&d_mean, B * T * sizeof(float)));
-    cudaCheck(cudaMalloc(&d_rstd, B * T * sizeof(float)));
-    cudaCheck(cudaMalloc(&d_inp, B * T * C * sizeof(float)));
-    cudaCheck(cudaMalloc(&d_weight, C * sizeof(float)));
-    cudaCheck(cudaMalloc(&d_bias, C * sizeof(float)));
-    cudaCheck(cudaMemcpy(d_inp, inp, B
* T * C * sizeof(float), cudaMemcpyHostToDevice)); - cudaCheck(cudaMemcpy(d_weight, weight, C * sizeof(float), cudaMemcpyHostToDevice)); - cudaCheck(cudaMemcpy(d_bias, bias, C * sizeof(float), cudaMemcpyHostToDevice)); + floatX* d_out; + floatX* d_mean; + floatX* d_rstd; + floatX* d_inp; + floatX* d_weight; + floatX* d_bias; + + cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(floatX))); + cudaCheck(cudaMalloc(&d_mean, B * T * sizeof(floatX))); + cudaCheck(cudaMalloc(&d_rstd, B * T * sizeof(floatX))); + cudaCheck(cudaMalloc(&d_inp, B * T * C * sizeof(floatX))); + cudaCheck(cudaMalloc(&d_weight, C * sizeof(floatX))); + cudaCheck(cudaMalloc(&d_bias, C * sizeof(floatX))); + cudaCheck(memcpy_convert(d_inp, inp, B * T * C)); + cudaCheck(memcpy_convert(d_weight, weight, C)); + cudaCheck(memcpy_convert(d_bias, bias, C)); // read kernel_num from command line int kernel_num = 2; @@ -473,9 +799,6 @@ int main(int argc, char **argv) { printf("Using kernel %d\n", kernel_num); int block_sizes[] = {32, 64, 128, 256, 512, 1024}; - float* out_gpu = (float*)malloc(B * T * C * sizeof(float)); - float* mean_gpu = (float*)malloc(B * T * sizeof(float)); - float* rstd_gpu = (float*)malloc(B * T * sizeof(float)); layernorm_forward_cpu(out, mean, rstd, inp, weight, bias, B, T, C);