make the new cooperative groups layernorm kernel the default. shaves off only about 1ms of the total running time though
karpathy committed Apr 11, 2024
1 parent 6f5ec06 commit 2c81198
Showing 1 changed file with 43 additions and 67 deletions.
train_gpt2.cu: 110 changes (43 additions & 67 deletions)
@@ -8,8 +8,11 @@ GPT-2 Transformer Neural Net trained in raw CUDA
 #include <time.h>
 #include <string.h>
 #include <unistd.h>
+#include <assert.h>
 #include <cublas_v2.h>
 #include <cuda_runtime.h>
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
 
 // ----------------------------------------------------------------------------
 // CUDA utils
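
The two cooperative_groups headers added above provide cg::tiled_partition, cg::reduce and cg::plus, which the new layernorm kernel below relies on (they ship with the CUDA toolkit; cg::reduce is available since CUDA 11.0). As a minimal sketch of what they do in isolation, here is a per-warp sum with an illustrative kernel name (warp_sum_demo) that is not part of this commit: each 32-thread warp sums its own 32-element chunk and lane 0 writes the result.

namespace cg = cooperative_groups;

// illustrative only: per-warp sum of 32 consecutive floats using the headers added above
__global__ void warp_sum_demo(const float* in, float* out, int n) {
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int base = (blockIdx.x * warp.meta_group_size() + warp.meta_group_rank()) * 32;
    int i = base + warp.thread_rank();
    float v = (i < n) ? in[i] : 0.0f;                   // guard the tail of the array
    float s = cg::reduce(warp, v, cg::plus<float>{});   // every lane receives the warp-wide sum
    if (warp.thread_rank() == 0) {
        out[base / 32] = s;
    }
}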
@@ -66,72 +69,52 @@ __global__ void encoder_forward_kernel2(float* out,
 }
 
 
-__global__ void mean_kernel(float* mean, float* inp, int N, int C, int block_size) {
-    extern __shared__ float shared[];
-    int idx = blockIdx.x; // range [0, B*T)
-    int tid = threadIdx.x; // range [0, block_size)
-    float* x = inp + idx * C;
-    // thread coarsening
+__global__ void layernorm_forward_kernel3(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd,
+                                          const float* __restrict__ inp, const float* __restrict__ weight,
+                                          const float* __restrict__ bias, int N, int C) {
+    namespace cg = cooperative_groups;
+    cg::thread_block block = cg::this_thread_block();
+    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
+    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
+    if(idx >= N) {
+        return;
+    }
+
+    // the row of input that this group of threads is responsible for
+    const float* x = inp + idx * C;
+
+    // mean
     float sum = 0.0f;
-    for (int i = tid; i < C; i += block_size) {
+    for (int i = warp.thread_rank(); i < C; i += warp.size()) {
         sum += x[i];
     }
-    shared[tid] = sum;
-    __syncthreads();
-    // reductions
-    for (int stride = block_size / 2; stride >= 1; stride /= 2) {
-        __syncthreads();
-        if (tid < stride) {
-            shared[tid] += shared[tid + stride];
-        }
-    }
-    // write the final result (at thread 0) to global memory
-    if (tid == 0) {
-        mean[idx] = shared[0] / C;
+    sum = cg::reduce(warp, sum, cg::plus<float>{});
+    float m = sum / C;
+    if(warp.thread_rank() == 0 && mean != nullptr) {
+        __stcs(mean + idx, m);
     }
-}
 
-__global__ void rstd_kernel(float* rstd, float* inp, float* mean, int N, int C, int block_size) {
-    extern __shared__ float shared[];
-    int idx = blockIdx.x; // range [0, B*T)
-    int tid = threadIdx.x; // range [0, block_size)
-    float* x = inp + idx * C;
-    float m = mean[idx];
-    // thread coarsening
-    float sum = 0.0f;
-    for (int i = tid; i < C; i += block_size) {
+    // rstd
+    sum = 0.0f;
+    for (int i = warp.thread_rank(); i < C; i += warp.size()) {
         float diff = x[i] - m;
         sum += diff * diff;
     }
-    shared[tid] = sum;
-    __syncthreads();
-    // reductions
-    for (int stride = block_size / 2; stride >= 1; stride /= 2) {
-        __syncthreads();
-        if (tid < stride) {
-            shared[tid] += shared[tid + stride];
-        }
-    }
-    // write the final result (at thread 0) to global memory
-    if (tid == 0) {
-        rstd[idx] = 1.0f / sqrtf(shared[0] / C + 1e-5f);
+    sum = cg::reduce(warp, sum, cg::plus<float>{});
+    float s = rsqrtf(sum / C + 1e-5f);
+    if(warp.thread_rank() == 0 && rstd != nullptr) {
+        __stcs(rstd + idx, s);
     }
-}
 
-__global__ void normalization_kernel(float* out, float* inp, float* mean, float* rstd,
-                                     float* weight, float* bias, int B, int T, int C) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-
-    int bt = idx / C;
-    int c = idx % C;
-
-    float m = mean[bt];
-    float s = rstd[bt];
-    float xi = inp[idx];
-    float n = s * (xi - m);
-    float o = n * weight[c] + bias[c];
-
-    out[idx] = o;
+    // final normalization and scaling by weight/bias
+    float* o = out + idx * C;
+    for (int c = warp.thread_rank(); c < C; c += warp.size()) {
+        // load and store using the .cs "streaming" hint to the compiler,
+        // indicating that this data will not be reused soon, and can be streamed through the caches
+        // this allows the threads to get more cache-hits for the (shared) weight and bias parameters
+        float n = s * (__ldcs(x+c) - m);
+        __stcs(o+c, n * weight[c] + bias[c]);
+    }
 }
 
 __global__ void add_bias(float* out, float* bias, int B, int T, int OC) {
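
For a 32-wide tile, cg::reduce(warp, sum, cg::plus<float>{}) in the new kernel performs the same warp-level sum that is otherwise hand-rolled with shuffle intrinsics. A minimal sketch of that hand-rolled equivalent, illustrative only and not part of this diff (warp_sum is a hypothetical helper name):

// hand-rolled counterpart of cg::reduce(warp, val, cg::plus<float>{}) for a full 32-lane warp
__device__ __forceinline__ float warp_sum(float val) {
    // butterfly reduction over lanes; after 5 rounds every lane holds the total
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_xor_sync(0xffffffffu, val, offset);
    }
    return val;
}

The __ldcs and __stcs intrinsics in the final loop are the CUDA built-ins behind the .cs comment above: they emit ld.global.cs / st.global.cs streaming loads and stores, so the per-row activations pass through the cache without displacing the weight and bias vectors that every row reuses.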
@@ -344,17 +327,10 @@ void encoder_forward(float* out,
 void layernorm_forward(float* out, float* mean, float* rstd,
                        float* inp, float* weight, float* bias,
                        int B, int T, int C) {
-    int N = B * T;
-    const int block_size = 128;
-    // in mean and rstd, threads cooperate within blocks via reductions
-    mean_kernel<<<B * T, block_size, block_size * sizeof(float)>>>(mean, inp, N, C, block_size);
-    cudaCheck(cudaGetLastError());
-    rstd_kernel<<<B * T, block_size, block_size * sizeof(float)>>>(rstd, inp, mean, N, C, block_size);
-    cudaCheck(cudaGetLastError());
-    // in the normalization, everything just gets flattened out
-    const int block_size2 = 256;
-    const int grid_size = CEIL_DIV(B * T * C, block_size2);
-    normalization_kernel<<<grid_size, block_size2>>>(out, inp, mean, rstd, weight, bias, B, T, C);
+    const int block_size = 1024;
+    const int N = B * T;
+    const int grid_size = CEIL_DIV(N * 32, block_size);
+    layernorm_forward_kernel3<<<grid_size, block_size>>>(out, mean, rstd, inp, weight, bias, N, C);
     cudaCheck(cudaGetLastError());
 }
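
The new launch gives every (b, t) row of the activation its own 32-thread warp: block_size = 1024 packs 32 warps into each block, so the grid shrinks to roughly N / 32 blocks. A small sanity check of that arithmetic with hypothetical shapes (the B and T values below are illustrative, and CEIL_DIV mirrors the macro used above):

#include <stdio.h>

#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))

int main(void) {
    int B = 4, T = 1024;                           // hypothetical batch size and sequence length
    int N = B * T;                                 // 4096 rows, one per (b, t) position
    int block_size = 1024;                         // 1024 threads = 32 warps per block
    int grid_size = CEIL_DIV(N * 32, block_size);  // one warp per row
    printf("%d blocks x %d warps/block = %d warps for %d rows\n",
           grid_size, block_size / 32, grid_size * (block_size / 32), N);
    // prints: 128 blocks x 32 warps/block = 4096 warps for 4096 rows
    return 0;
}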

