
Commit

Merge branch 'matmul-backward-bias' of https://github.com/al0vya/llm.c into al0vya-matmul-backward-bias
karpathy committed Apr 22, 2024
2 parents f813d63 + 35393b4 commit a42f739
Showing 2 changed files with 93 additions and 25 deletions.
53 changes: 52 additions & 1 deletion dev/cuda/matmul_backward_bias.cu
@@ -124,6 +124,47 @@ __global__ void matmul_backward_bias_kernel3(float* dbias, const float* dout, int B, int T, int OC) {
}
}

// this kernel essentially performs a column-wise reduction over dout,
// which in pytorch would simply look like: dbias = dout.sum((0,1))
// the philosophy of this kernel is to employ one block to reduce along
// several columns, whereby each block has a "width" of 32 columns to ensure
// coalesced access. near the end of the column-wise reduction, we accumulate
// the reductions performed by the warps in each block via shared memory
__global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, int B, int T, int OC) {

const int vstep = blockDim.x / warpSize; // number of warps per block
const int row = threadIdx.x >> 5; // basically warp_id
const int tl = blockIdx.x * warpSize; // first column owned by this block
const int lane = threadIdx.x & (warpSize - 1); // lane index within the warp

const float* dout_col = dout + tl + lane; // this thread's column of dout

extern __shared__ float smem[];

float dout_sum = 0.0f;
// column reduction: each warp loops over the B * T rows,
// striding by vstep (the number of warps per block)
for (int j = row; j < B * T; j += vstep) {
dout_sum += dout_col[j * OC];
}

smem[lane + row * warpSize] = dout_sum;

// every thread in the block reaches this point (the loop bounds above
// do not depend on threadIdx), so it is safe to call __syncthreads()
__syncthreads();

dout_sum = 0.0f;

// warp 0 combines the per-warp partial sums for its 32 columns
if (row == 0) {
for (int j = 0; j < vstep; j++) {
dout_sum += smem[lane + j * warpSize];
}

dbias[tl + lane] += dout_sum;
}
}
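
// [editorial sketch, not part of the commit] for reference, the same
// reduction on the CPU is a plain column sum over the (B*T, OC) view of
// dout, matching the pytorch one-liner dbias = dout.sum((0,1)) above;
// the function name here is invented for illustration
void matmul_backward_bias_reference(float* dbias, const float* dout, int B, int T, int OC) {
    for (int o = 0; o < OC; o++) {
        float sum = 0.0f;
        for (int bt = 0; bt < B * T; bt++) {
            sum += dout[bt * OC + o]; // dout is (B, T, OC), flattened over (B, T)
        }
        dbias[o] += sum; // += mirrors the kernel's accumulate semantics
    }
}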

// ----------------------------------------------------------------------------
// kernel launcher

@@ -152,6 +193,13 @@ void matmul_backward_bias3(float* dinp, float* dweight, float* dbias,
matmul_backward_bias_kernel3<<<OC, block_size>>>(dbias, dout, B, T, OC);
}

void matmul_backward_bias4(float* dinp, float* dweight, float* dbias,
float* dout, float* inp, float* weight, float* ones,
int B, int T, int C, int OC, int block_size) {
const int grid_size = OC / 32; // for now, OC must be divisible by 32 for this kernel to work
matmul_backward_bias_kernel4<<<grid_size, block_size, block_size * sizeof(float)>>>(dbias, dout, B, T, OC);
}
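
// [editorial usage sketch with hypothetical sizes, not taken from the diff]
// with OC = 3072 and block_size = 256: grid_size = 3072 / 32 = 96 blocks,
// each covering 32 columns with 256 / 32 = 8 warps striding over the rows,
// using 256 * sizeof(float) = 1024 bytes of dynamic shared memory per block
void example_matmul_backward_bias4(float* d_dbias, float* d_dout) {
    // kernel 4 reads only dout and accumulates into dbias, so the unused
    // dinp/dweight/inp/weight/ones arguments can be passed as NULL
    matmul_backward_bias4(NULL, NULL, d_dbias, d_dout, NULL, NULL, NULL,
                          8, 1024, 768, 3072, 256); // B, T, C, OC, block_size
}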

void matmul_backward_bias(int kernel_num,
float* dinp, float* dweight, float* dbias,
float* dout, float* inp, float* weight, float* ones,
@@ -166,6 +214,9 @@ void matmul_backward_bias(int kernel_num,
case 3:
matmul_backward_bias3(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC, block_size);
break;
case 4:
matmul_backward_bias4(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC, block_size);
break;
default:
printf("Invalid kernel number\n");
exit(1);
@@ -230,7 +281,7 @@ int main(int argc, char **argv) {
matmul_backward_bias(kernel_num, NULL, NULL, d_dbias, d_dout, NULL, NULL, NULL, B, T, C, OC, 128);
// compare
printf("Checking correctness...\n");
validate_result(d_dbias, dbias, "dbias", OC, 1e-3f);
validate_result(d_dbias, dbias, "dbias", OC, 5e-3f);
printf("All results match for block_size=%d.\n\n", block_size);
}
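
// [editorial aside, a hedged guess; the commit does not state a reason]
// the tolerance above is loosened from 1e-3f to 5e-3f because kernel 4
// accumulates the B * T column entries in a different order than the CPU
// reference, and floating-point addition is not associative, so larger
// (but benign) round-off differences are expected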

65 changes: 41 additions & 24 deletions train_gpt2.cu
@@ -532,27 +532,44 @@ __global__ void softmax_forward_kernel7(float* out, const float* inp, int N, int
}
}

// cooperative groups solution, one warp per output channel
__global__ void matmul_backward_bias_kernel2(float* dbias, const float* dout, int B, int T, int OC) {
// dout is (B, T, OC), dbias is (OC)
// e.g. if block_size = 128, then we have 4 warps per block, each in charge of one output channel
namespace cg = cooperative_groups;
cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
// meta_group_size is the number of warps in a block (e.g. 4), meta_group_rank is the warp index (0,1,2,3)
int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
if(idx >= OC) { return; }
int BT = B * T; // number of elements to reduce in total, per channel
// first, thread coarsening to sum reduce the problem size from B*T to 32
float sum = 0.0f;
for(int i = warp.thread_rank(); i < BT; i += warp.size()) {
sum += dout[i * OC + idx];
}
// now do a warp-level reduce to get the sum across the 32 threads in this warp
sum = cg::reduce(warp, sum, cg::plus<float>{});
// write the result to output (global memory)
if(warp.thread_rank() == 0) {
dbias[idx] += sum;
// this kernel essentially performs a column-wise reduction over dout,
// which in pytorch would simply look like: dbias = dout.sum((0,1))
// the philosophy of this kernel is to employ one block to reduce along
// several columns, whereby each block has a "width" of 32 columns to ensure
// coalesced access. near the end of the column-wise reduction, we accumulate
// the reductions performed by the warps in each block via shared memory
__global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, int B, int T, int OC) {

const int vstep = blockDim.x / warpSize; // number of warps per block
const int row = threadIdx.x >> 5; // basically warp_id
const int tl = blockIdx.x * warpSize; // first column owned by this block
const int lane = threadIdx.x & (warpSize - 1); // lane index within the warp

const float* dout_col = dout + tl + lane; // this thread's column of dout

extern __shared__ float smem[];

float dout_sum = 0.0f;
// column reduction: each warp loops over the B * T rows,
// striding by vstep (the number of warps per block)
for (int j = row; j < B * T; j += vstep) {
dout_sum += dout_col[j * OC];
}

smem[lane + row * warpSize] = dout_sum;

// every thread in the block reaches this point (the loop bounds above
// do not depend on threadIdx), so it is safe to call __syncthreads()
__syncthreads();

dout_sum = 0.0f;

// warp 0 combines the per-warp partial sums for its 32 columns
if (row == 0) {
for (int j = 0; j < vstep; j++) {
dout_sum += smem[lane + j * warpSize];
}

dbias[tl + lane] += dout_sum;
}
}

@@ -973,9 +990,9 @@ void matmul_backward(float* dinp, float* dweight, float* dbias,
cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, C, OC, B*T, &one, inp, C, dout, OC, &one, dweight, C));
// backward to bias, if given, does a +=
if (dbias != NULL) {
const int block_size = 512;
const int grid_size = CEIL_DIV(OC * 32, block_size);
matmul_backward_bias_kernel2<<<grid_size, block_size>>>(dbias, dout, B, T, OC);
const int block_size = 1024;
const int grid_size = OC / 32; // for now, OC must be divisible by 32 for this kernel to work
matmul_backward_bias_kernel4<<<grid_size, block_size, block_size * sizeof(float)>>>(dbias, dout, B, T, OC);
cudaCheck(cudaGetLastError());
}
}
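
// [editorial sketch, not part of the commit] grid_size = OC / 32 assumes
// OC is a multiple of 32 (true for GPT-2's channel sizes); a guard before
// the launch would make the assumption explicit, since any leftover
// OC % 32 columns would otherwise never be reduced:
// assert(OC % 32 == 0); // requires <assert.h>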
