speed up the backward bias kernel by 45% and speed up the full running time by 1%
karpathy committed Apr 20, 2024
1 parent 9b722ce commit 8488669
Showing 2 changed files with 278 additions and 27 deletions.
255 changes: 255 additions & 0 deletions dev/cuda/matmul_backward_bias.cu
@@ -0,0 +1,255 @@
/*
Kernels for matmul backward pass bias only.
Compile example:
nvcc -O3 matmul_backward_bias.cu -lineinfo -o matmul_backward_bias
./matmul_backward_bias 1
./matmul_backward_bias 2
./matmul_backward_bias 3
ncu:
sudo ncu --set full --import-source yes -o bias -f ./matmul_backward_bias 1
*/
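
// note: the helpers used below (cudaCheck, ceil_div, make_random_float, make_zeros_float,
// validate_result, benchmark_kernel) come from the local common.h header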

#include <stdio.h>
#include <stdlib.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <omp.h>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include "common.h"

// ----------------------------------------------------------------------------
// CPU code reference

void matmul_backward_bias_cpu(float* dinp, float* dweight, float* dbias,
                              float* dout, float* inp, float* weight,
                              int B, int T, int C, int OC) {
    for (int o = 0; o < OC; o++) {
        double sum = 0.0;
        for (int b = 0; b < B; b++) {
            for (int t = 0; t < T; t++) {
                float* dout_bt = dout + b * T * OC + t * OC;
                sum += dout_bt[o];
            }
        }
        dbias[o] = sum;
    }
}
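
// aside: the bias gradient is just the column-sum of dout over all B*T rows, i.e. dout^T @ ones.
// a commented-out sketch of the same reduction as a single cuBLAS GEMV, assuming a cublas_handle /
// cublasCheck like the ones in train_gpt2.cu and a device buffer `ones` holding B*T values of 1.0f
// (this file does not take this path; possibly why the launchers below accept an unused `ones` argument):
//
//     const float alpha = 1.0f, beta = 1.0f; // beta = 1 matches the += that the GPU kernels do
//     // in cuBLAS' column-major view, dout is an (OC x B*T) matrix with leading dimension OC
//     cublasCheck(cublasSgemv(cublas_handle, CUBLAS_OP_N, OC, B*T, &alpha,
//                             dout, OC, ones, 1, &beta, dbias, 1));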

// ----------------------------------------------------------------------------
// GPU kernels

__global__ void matmul_backward_bias_kernel1(float* dbias, const float* dout, int B, int T, int OC) {
    extern __shared__ float shared[];
    int o = blockIdx.x; // range [0, OC)
    int tid = threadIdx.x; // range [0, block_size)
    int block_size = blockDim.x;
    const float* x = dout + o;
    // thread coarsening
    double sum = 0.0;
    for (int i = tid; i < B * T; i += block_size) {
        sum += x[i * OC];
    }
    shared[tid] = (float) sum;
    __syncthreads();
    // reductions
    for (int stride = block_size / 2; stride >= 1; stride /= 2) {
        __syncthreads();
        if (tid < stride) {
            shared[tid] += shared[tid + stride];
        }
    }
    // write the final result (at thread 0) to global memory
    if (tid == 0) {
        dbias[o] += shared[0];
    }
}
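
// note on kernel1: with the sizes used in main() below (B=8, T=1024) and e.g. block_size = 512,
// each thread first accumulates 8192 / 512 = 16 values, then the shared-memory tree reduction
// takes log2(512) = 9 steps; one block handles one output channel (grid of OC blocks, see the launcher below)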

// cooperative groups solution, one warp per output channel
__global__ void matmul_backward_bias_kernel2(float* dbias, const float* dout, int B, int T, int OC) {
    // dout is (B, T, OC), dbias is (OC)
    // e.g. if block_size = 128, then we have 4 warps per block, each in charge of one output channel
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    // meta_group_size is the number of warps in a block (e.g. 4), meta_group_rank is the warp index (0,1,2,3)
    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
    if(idx >= OC) { return; }
    int BT = B * T; // number of elements to reduce in total, per channel
    // first, thread coarsening to sum reduce the problem size from B*T to 32
    float sum = 0.0f;
    for(int i = warp.thread_rank(); i < BT; i += warp.size()) {
        sum += dout[i * OC + idx];
    }
    // now do a warp-level reduce to get the sum across the 32 threads in this warp
    sum = cg::reduce(warp, sum, cg::plus<float>{});
    // write the result to output (global memory)
    if(warp.thread_rank() == 0) {
        dbias[idx] += sum;
    }
}
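
// note on kernel2's launch shape with the defaults in main() below (B=8, T=1024, OC=3072):
// one warp per output channel, so each warp strides through B*T = 8192 values of its channel
// (256 adds per thread) before a single 32-wide cg::reduce produces that channel's sum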

__global__ void matmul_backward_bias_kernel3(float* dbias, const float* dout, int B, int T, int OC) {
    // dout is (B, T, OC), dbias is (OC)
    // in this version of the kernel the entire block of block_size threads is dedicated to one output channel
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    __shared__ float shared_sum[32]; // block_size max is 1024 threads = 32 warps of 32 threads
    int BT = B * T; // number of elements to reduce in total, per channel
    int num_warps = blockDim.x / 32;
    int warp_id = threadIdx.x / 32;
    int lane_id = threadIdx.x % 32;
    int idx = blockIdx.x; // simply one block per output channel
    // round 1: thread coarsening to reduce the problem size from B*T to block_size
    float thread_sum = 0.0f;
    for(int i = threadIdx.x; i < BT; i += blockDim.x) {
        thread_sum += dout[i * OC + idx];
    }
    // now do a warp-level reduce to get the sum across the 32 threads in each warp
    float warp_sum = cg::reduce(warp, thread_sum, cg::plus<float>{});
    // store the warp sum in shared memory (a lane_id == 0 guard is not needed: all lanes hold the same value)
    shared_sum[warp_id] = warp_sum;
    __syncthreads();
    // load results from shared memory to threads, pad with zeros for threads that are out of bounds
    warp_sum = (lane_id < num_warps) ? shared_sum[lane_id] : 0.0f;
    // now reduce the warp-level reductions
    float block_sum = cg::reduce(warp, warp_sum, cg::plus<float>{}); // sum(x)
    // write the result to output (global memory)
    if(threadIdx.x == 0) {
        dbias[idx] += block_sum;
    }
}
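
// note on kernel3: with e.g. block_size = 256 the block has 8 warps; each warp deposits its partial
// sum into shared_sum, and the final cg::reduce runs over 32 lanes, with lanes >= num_warps contributing 0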

// ----------------------------------------------------------------------------
// kernel launcher

// version 1: one block per output channel, block-wide reduction in shared memory
void matmul_backward_bias1(float* dinp, float* dweight, float* dbias,
                           float* dout, float* inp, float* weight, float* ones,
                           int B, int T, int C, int OC, int block_size) {
    dim3 block_dim(block_size);
    dim3 grid_dim(OC);
    size_t shared_mem_size = block_size * sizeof(float);
    matmul_backward_bias_kernel1<<<grid_dim, block_dim, shared_mem_size>>>(dbias, dout, B, T, OC);
}

void matmul_backward_bias2(float* dinp, float* dweight, float* dbias,
                           float* dout, float* inp, float* weight, float* ones,
                           int B, int T, int C, int OC, int block_size) {
    // block_size 512 seems best
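    // one warp (32 threads) is assigned to each of the OC output channels, hence OC * 32 threads in total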
    const int grid_size = ceil_div(OC * 32, block_size);
    matmul_backward_bias_kernel2<<<grid_size, block_size>>>(dbias, dout, B, T, OC);
}

void matmul_backward_bias3(float* dinp, float* dweight, float* dbias,
                           float* dout, float* inp, float* weight, float* ones,
                           int B, int T, int C, int OC, int block_size) {
    // block_size 256 seems best
    matmul_backward_bias_kernel3<<<OC, block_size>>>(dbias, dout, B, T, OC);
}

void matmul_backward_bias(int kernel_num,
                          float* dinp, float* dweight, float* dbias,
                          float* dout, float* inp, float* weight, float* ones,
                          int B, int T, int C, int OC, int block_size) {
    switch (kernel_num) {
        case 1:
            matmul_backward_bias1(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC, block_size);
            break;
        case 2:
            matmul_backward_bias2(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC, block_size);
            break;
        case 3:
            matmul_backward_bias3(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC, block_size);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
    }
}

// ----------------------------------------------------------------------------

int main(int argc, char **argv) {
    srand(0);

    int B = 8;
    int T = 1024;
    int C = 768;
    int OC = 768 * 4; // expansion of 4, e.g. in the MLP

    // set up the device
    int deviceIdx = 0;
    cudaCheck(cudaSetDevice(deviceIdx));
    cudaDeviceProp deviceProp;
    cudaGetDeviceProperties(&deviceProp, deviceIdx);
    printf("Device %d: %s\n", deviceIdx, deviceProp.name);

    // read kernel_num from command line
    int kernel_num = 1;
    if (argc > 1) {
        kernel_num = atoi(argv[1]);
    }
    printf("Using kernel %d\n", kernel_num);

    // create host memory of random numbers
    float* dbias = make_zeros_float(OC);
    float* dout = make_random_float(B * T * OC);

    // move to GPU
    float* d_dbias;
    float* d_dout;
    cudaCheck(cudaMalloc(&d_dbias, OC * sizeof(float)));
    cudaCheck(cudaMalloc(&d_dout, B * T * OC * sizeof(float)));
    cudaCheck(cudaMemcpy(d_dbias, dbias, OC * sizeof(float), cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_dout, dout, B * T * OC * sizeof(float), cudaMemcpyHostToDevice));

    // ncu debugging / profiling, do a single call
    // int block_size_debug;
    // if (kernel_num == 1) { block_size_debug = 512;
    // } else if (kernel_num == 2) { block_size_debug = 512;
    // } else { block_size_debug = 256; }
    // printf("kernel %d, block_size %d\n", kernel_num, block_size_debug);
    // matmul_backward_bias(kernel_num, NULL, NULL, d_dbias, d_dout, NULL, NULL, NULL, B, T, C, OC, block_size_debug);
    // exit(EXIT_SUCCESS);

    int block_sizes[] = {32, 64, 128, 256, 512, 1024};

    // calculate the CPU reference
    matmul_backward_bias_cpu(NULL, NULL, dbias, dout, NULL, NULL, B, T, C, OC);

    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
        int block_size = block_sizes[j];
        // memset the bias to zero
        cudaCheck(cudaMemset(d_dbias, 0, OC * sizeof(float)));
        // calculate the GPU version
        matmul_backward_bias(kernel_num, NULL, NULL, d_dbias, d_dout, NULL, NULL, NULL, B, T, C, OC, block_size);
        // compare
        printf("Checking correctness...\n");
        validate_result(d_dbias, dbias, "dbias", OC, 1e-3f);
        printf("All results match for block_size=%d.\n\n", block_size);
    }

    // now benchmark the kernel
    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
        int block_size = block_sizes[j];
        float *d_dinp = NULL, *d_dweight = NULL, *d_inp = NULL, *d_weight = NULL, *d_ones = NULL;
        int repeat_times = 2000;
        float elapsed_time = benchmark_kernel(repeat_times, matmul_backward_bias, kernel_num,
                                              d_dinp, d_dweight, d_dbias, d_dout, d_inp, d_weight, d_ones,
                                              B, T, C, OC, block_size);
        printf("block_size %d time %.4f ms\n", block_size, elapsed_time);
    }

    // cleanups
    free(dbias);
    free(dout);
    cudaCheck(cudaFree(d_dbias));
    cudaCheck(cudaFree(d_dout));

    return 0;
}
50 changes: 23 additions & 27 deletions train_gpt2.cu
@@ -532,29 +532,27 @@ __global__ void softmax_forward_kernel7(float* out, const float* inp, int N, int
     }
 }
 
-__global__ void matmul_backward_bias_kernel_faster(float* dbias, const float* dout, int B, int T, int OC) {
-    extern __shared__ float shared[];
-    int o = blockIdx.x; // range [0, OC)
-    int tid = threadIdx.x; // range [0, block_size)
-    int block_size = blockDim.x;
-    const float* x = dout + o;
-    // thread coarsening
-    double sum = 0.0;
-    for (int i = tid; i < B * T; i += block_size) {
-        sum += x[i * OC];
-    }
-    shared[tid] = (float) sum;
-    __syncthreads();
-    // reductions
-    for (int stride = block_size / 2; stride >= 1; stride /= 2) {
-        __syncthreads();
-        if (tid < stride) {
-            shared[tid] += shared[tid + stride];
-        }
-    }
-    // write the final result (at thread 0) to global memory
-    if (tid == 0) {
-        dbias[o] += shared[0];
+// cooperative groups solution, one warp per output channel
+__global__ void matmul_backward_bias_kernel2(float* dbias, const float* dout, int B, int T, int OC) {
+    // dout is (B, T, OC), dbias is (OC)
+    // e.g. if block_size = 128, then we have 4 warps per block, each in charge of one output channel
+    namespace cg = cooperative_groups;
+    cg::thread_block block = cg::this_thread_block();
+    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
+    // meta_group_size is the number of warps in a block (e.g. 4), meta_group_rank is the warp index (0,1,2,3)
+    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
+    if(idx >= OC) { return; }
+    int BT = B * T; // number of elements to reduce in total, per channel
+    // first, thread coarsening to sum reduce the problem size from B*T to 32
+    float sum = 0.0f;
+    for(int i = warp.thread_rank(); i < BT; i += warp.size()) {
+        sum += dout[i * OC + idx];
+    }
+    // now do a warp-level reduce to get the sum across the 32 threads in this warp
+    sum = cg::reduce(warp, sum, cg::plus<float>{});
+    // write the result to output (global memory)
+    if(warp.thread_rank() == 0) {
+        dbias[idx] += sum;
     }
 }

@@ -971,11 +969,9 @@ void matmul_backward(float* dinp, float* dweight, float* dbias,
     cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, C, OC, B*T, &one, inp, C, dout, OC, &one, dweight, C));
     // backward to bias, if given, does a +=
     if (dbias != NULL) {
-        const int block_size=512;
-        dim3 block_dim(block_size);
-        dim3 grid_dim(OC);
-        size_t shared_mem_size = block_size * sizeof(float);
-        matmul_backward_bias_kernel_faster<<<grid_dim, block_dim, shared_mem_size>>>(dbias, dout, B, T, OC);
+        const int block_size = 512;
+        const int grid_size = CEIL_DIV(OC * 32, block_size);
+        matmul_backward_bias_kernel2<<<grid_size, block_size>>>(dbias, dout, B, T, OC);
         cudaCheck(cudaGetLastError());
     }
 }
