add more comments to explain the philosophy behind the kernel
al0vya committed Apr 22, 2024
1 parent b82ec20 commit 35393b4
Showing 2 changed files with 12 additions and 8 deletions.
10 changes: 6 additions & 4 deletions dev/cuda/matmul_backward_bias.cu
@@ -124,10 +124,12 @@ __global__ void matmul_backward_bias_kernel3(float* dbias, const float* dout, in
 }
 }

-// this kernel essentially performs a column-wise reduction over dout
-// the philosophy of this kernel is to employ one block to reduce
-// along several columns and then to share and accumulate the
-// reductions performed by different warps via shared memory
+// this kernel essentially performs a column-wise reduction over dout,
+// which in pytorch would simply look like: dbias = dout.sum((0,1))
+// the philosophy of this kernel is to employ one block to reduce along
+// several columns, whereby each block has a "width" of 32 columns to ensure
+// coalesced access. near the end of the column-wise reduction, we accumulate
+// the reductions performed by the warps in each block via shared memory
 __global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, int B, int T, int OC) {

     const int vstep = blockDim.x / warpSize;
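For context, the kernel body that the new comment describes is collapsed in the diff above. Purely as illustration, a minimal sketch of a kernel matching that description might look like the following. The launch configuration here is an assumption, not taken from the diff: one block per 32 output columns (OC divisible by 32), blockDim.x a multiple of 32, and blockDim.x floats of dynamic shared memory.

// a minimal sketch, not the exact kernel from this commit; assumed launch:
// matmul_backward_bias_sketch<<<OC / 32, block_size, block_size * sizeof(float)>>>(...)
__global__ void matmul_backward_bias_sketch(float* dbias, const float* dout, int B, int T, int OC) {
    extern __shared__ float smem[];              // blockDim.x floats, assumed
    const int vstep   = blockDim.x / warpSize;   // number of warps per block
    const int warp_id = threadIdx.x / warpSize;  // which warp this thread is in
    const int lane_id = threadIdx.x % warpSize;  // 0..31 within the warp
    const int col     = blockIdx.x * warpSize + lane_id; // this thread's column

    // each warp strides over the B*T rows; within a warp, the 32 lanes read 32
    // consecutive columns of the same row, so the global loads are coalesced
    float sum = 0.0f;
    for (int row = warp_id; row < B * T; row += vstep) {
        sum += dout[row * OC + col];
    }

    // share the per-warp partial sums for each column via shared memory ...
    smem[warp_id * warpSize + lane_id] = sum;
    __syncthreads();

    // ... and let warp 0 accumulate them and write out the final column sums
    if (warp_id == 0) {
        float acc = 0.0f;
        for (int w = 0; w < vstep; w++) {
            acc += smem[w * warpSize + lane_id];
        }
        dbias[col] += acc;
    }
}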
10 changes: 6 additions & 4 deletions train_gpt2.cu
@@ -532,10 +532,12 @@ __global__ void softmax_forward_kernel7(float* out, const float* inp, int N, int
 }
 }

-// this kernel essentially performs a column-wise reduction over dout
-// the philosophy of the kernel is to employ one block to reduce
-// along several columns and then to share and accumulate the
-// reductions performed by different warps via shared memory
+// this kernel essentially performs a column-wise reduction over dout,
+// which in pytorch would simply look like: dbias = dout.sum((0,1))
+// the philosophy of this kernel is to employ one block to reduce along
+// several columns, whereby each block has a "width" of 32 columns to ensure
+// coalesced access. near the end of the column-wise reduction, we accumulate
+// the reductions performed by the warps in each block via shared memory
 __global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, int B, int T, int OC) {

     const int vstep = blockDim.x / warpSize;
