diff --git a/dev/cuda/matmul_backward_bias.cu b/dev/cuda/matmul_backward_bias.cu index 15753d8bd..c550c81f8 100644 --- a/dev/cuda/matmul_backward_bias.cu +++ b/dev/cuda/matmul_backward_bias.cu @@ -124,6 +124,47 @@ __global__ void matmul_backward_bias_kernel3(float* dbias, const float* dout, in } } +// this kernel essentially performs a column-wise reduction over dout, +// which in pytorch would simply look like: dbias = dout.sum((0,1)) +// the philosophy of this kernel is to employ one block to reduce along +// several columns, whereby each block has a "width" of 32 columns to ensure +// coalesced access. near the end of the column-wise reduction, we accumulate +// the reductions performed by the warps in each block via shared memory +__global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, int B, int T, int OC) { + + const int vstep = blockDim.x / warpSize; + const int row = threadIdx.x >> 5; // basically warp_id + const int tl = blockIdx.x * warpSize; + const int lane = threadIdx.x & (warpSize - 1); + + const float* dout_col = dout + tl + lane; + + extern __shared__ float smem[]; + + float dout_sum = 0.0f; + // column reductions by looping through the rows: + // the loop should not exceed B * T rows + for (int j = row; j < B * T; j += vstep) { + dout_sum += dout_col[j * OC]; + } + + smem[lane + row * warpSize] = dout_sum; + + // our kernel assures that entire blocks are running + // inside the loop, so we can safely call sync I believe + __syncthreads(); + + dout_sum = 0.0f; + + if (row == 0) { + for (int j = 0; j < vstep; j++) { + dout_sum += smem[lane + j * warpSize]; + } + + dbias[tl + lane] += dout_sum; + } +} + // ---------------------------------------------------------------------------- // kernel launcher @@ -152,6 +193,13 @@ void matmul_backward_bias3(float* dinp, float* dweight, float* dbias, matmul_backward_bias_kernel3<<>>(dbias, dout, B, T, OC); } +void matmul_backward_bias4(float* dinp, float* dweight, float* dbias, + float* dout, float* inp, float* weight, float* ones, + int B, int T, int C, int OC, int block_size) { + const int grid_size = OC / 32; // for now, OC must be divisible by 32 for this kernel to work + matmul_backward_bias_kernel4<<>>(dbias, dout, B, T, OC); +} + void matmul_backward_bias(int kernel_num, float* dinp, float* dweight, float* dbias, float* dout, float* inp, float* weight, float* ones, @@ -166,6 +214,9 @@ void matmul_backward_bias(int kernel_num, case 3: matmul_backward_bias3(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC, block_size); break; + case 4: + matmul_backward_bias4(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC, block_size); + break; default: printf("Invalid kernel number\n"); exit(1); @@ -230,7 +281,7 @@ int main(int argc, char **argv) { matmul_backward_bias(kernel_num, NULL, NULL, d_dbias, d_dout, NULL, NULL, NULL, B, T, C, OC, 128); // compare printf("Checking correctness...\n"); - validate_result(d_dbias, dbias, "dbias", OC, 1e-3f); + validate_result(d_dbias, dbias, "dbias", OC, 5e-3f); printf("All results match for block_size=%d.\n\n", block_size); } diff --git a/train_gpt2.cu b/train_gpt2.cu index de7971392..079ecb42b 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -532,27 +532,44 @@ __global__ void softmax_forward_kernel7(float* out, const float* inp, int N, int } } -// cooperative groups solution, one warp per output channel -__global__ void matmul_backward_bias_kernel2(float* dbias, const float* dout, int B, int T, int OC) { - // dout is (B, T, OC), dbias is (OC) - // e.g. if block_size = 128, then we have 4 warps per block, each in charge of one output channel - namespace cg = cooperative_groups; - cg::thread_block block = cg::this_thread_block(); - cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); - // meta_group_size is the number of warps in a block (e.g. 4), meta_group_rank is the warp index (0,1,2,3) - int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank(); - if(idx >= OC) { return; } - int BT = B * T; // number of elements to reduce in total, per channel - // first, thread coarsening to sum reduce the problem size from B*T to 32 - float sum = 0.0f; - for(int i = warp.thread_rank(); i < BT; i += warp.size()) { - sum += dout[i * OC + idx]; - } - // now do a warp-level reduce to get the sum across the 32 threads in this warp - sum = cg::reduce(warp, sum, cg::plus{}); - // write the result to output (global memory) - if(warp.thread_rank() == 0) { - dbias[idx] += sum; +// this kernel essentially performs a column-wise reduction over dout, +// which in pytorch would simply look like: dbias = dout.sum((0,1)) +// the philosophy of this kernel is to employ one block to reduce along +// several columns, whereby each block has a "width" of 32 columns to ensure +// coalesced access. near the end of the column-wise reduction, we accumulate +// the reductions performed by the warps in each block via shared memory +__global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, int B, int T, int OC) { + + const int vstep = blockDim.x / warpSize; + const int row = threadIdx.x >> 5; // basically warp_id + const int tl = blockIdx.x * warpSize; + const int lane = threadIdx.x & (warpSize - 1); + + const float* dout_col = dout + tl + lane; + + extern __shared__ float smem[]; + + float dout_sum = 0.0f; + // column reductions by looping through the rows: + // the loop should not exceed B * T rows + for (int j = row; j < B * T; j += vstep) { + dout_sum += dout_col[j * OC]; + } + + smem[lane + row * warpSize] = dout_sum; + + // our kernel assures that entire blocks are running + // inside the loop, so we can safely call sync I believe + __syncthreads(); + + dout_sum = 0.0f; + + if (row == 0) { + for (int j = 0; j < vstep; j++) { + dout_sum += smem[lane + j * warpSize]; + } + + dbias[tl + lane] += dout_sum; } } @@ -973,9 +990,9 @@ void matmul_backward(float* dinp, float* dweight, float* dbias, cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, C, OC, B*T, &one, inp, C, dout, OC, &one, dweight, C)); // backward to bias, if given, does a += if (dbias != NULL) { - const int block_size = 512; - const int grid_size = CEIL_DIV(OC * 32, block_size); - matmul_backward_bias_kernel2<<>>(dbias, dout, B, T, OC); + const int block_size = 1024; + const int grid_size = OC / 32; // for now, OC must be divisible by 32 for this kernel to work + matmul_backward_bias_kernel4<<>>(dbias, dout, B, T, OC); cudaCheck(cudaGetLastError()); } }