
[dev/cuda] Added warpsize as a constant expr for dev/cuda files #438

Open · wants to merge 1 commit into master
9 changes: 6 additions & 3 deletions dev/cuda/attention_forward.cu
@@ -61,6 +61,9 @@ version 11 is kernel 10 skipping FP16/FP32 conversions (full FP16/BF16 network)
 // CUDA & cuDNN setup
 static bool first_run_validation = true; // always run e.g. permute on 1st run
 
+// warp size set as a compile-time constant; this allows the compiler to optimize
+constexpr int WARP_SIZE = 32;
Contributor (review comment on the constexpr line): this should probably go into common.h

 #ifdef ENABLE_CUDNN
 #include <cudnn_frontend.h>
 namespace fe = cudnn_frontend;
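Following the reviewer's suggestion, a minimal sketch of what hoisting the constant into the shared header could look like; the include-guard name and placement are assumptions, since common.h's actual contents are not shown in this diff.

// dev/cuda/common.h (sketch, not the header's verbatim contents)
#ifndef COMMON_H
#define COMMON_H

// warp size as a compile-time constant, shared by all dev/cuda kernels;
// unlike the built-in warpSize variable (a runtime value), a constexpr
// lets the compiler reduce / WARP_SIZE and % WARP_SIZE to shifts and masks
constexpr int WARP_SIZE = 32;

#endif // COMMON_H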
@@ -251,11 +254,11 @@ __global__ void softmax_forward_kernel4(float* out, const float* inp, int N, int
     extern __shared__ float shared[];
     int idx = blockIdx.x;
     int tid = threadIdx.x;
-    int warpId = threadIdx.x / 32; // warp index within a block
-    int laneId = threadIdx.x % 32; // thread index within a warp
+    int warpId = threadIdx.x / WARP_SIZE; // warp index within a block
+    int laneId = threadIdx.x % WARP_SIZE; // thread index within a warp
 
     // the number of warps per block. recall that blockDim.x is block_size
-    int warpsPerBlock = blockDim.x / 32;
+    int warpsPerBlock = blockDim.x / WARP_SIZE;
 
     // shared[] must be allocated to have 2 * warpsPerBlock elements
     // first half for max values, the second half for sum values
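The last two context lines put a sizing constraint on the caller. A brief sketch of the launch-side arithmetic they imply; grid_size and the kernel arguments here are assumptions about the surrounding launcher, not quoted from it.

// sketch: the caller must size dynamic shared memory as 2 * warpsPerBlock floats
int warpsPerBlock = block_size / WARP_SIZE;
size_t shared_mem_size = 2 * warpsPerBlock * sizeof(float); // max values + sum values
softmax_forward_kernel4<<<grid_size, block_size, shared_mem_size>>>(out, inp, N, C);
cudaCheck(cudaGetLastError());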
44 changes: 24 additions & 20 deletions dev/cuda/classifier_fused.cu
@@ -32,6 +32,10 @@ typedef float floatX;
 #endif
 typedef Packed128<floatX> x128;
 
+
+// warp size set as a compile-time constant; this allows the compiler to optimize
+constexpr int WARP_SIZE = 32;
+
 // ----------------------------------------------------------------------------
 // CPU code reference

@@ -218,11 +222,11 @@ __device__ SoftmaxParams prepare_softmax_blockwide(cg::thread_block_tile<32>& wa
     // two reductions of up to 1024 threads:
     // 1) inside warp (shuffle), 2) cross-warp (shared memory), 3) inside warp (shuffle)
     // this results in much cleaner assembly than a multi-warp cg::reduce
-    __shared__ float shared_maxval[32];
-    __shared__ float shared_sumval[32];
-    int num_warps = blockDim.x / 32;
-    int warp_id = threadIdx.x / 32;
-    int lane_id = threadIdx.x % 32;
+    __shared__ float shared_maxval[WARP_SIZE];
+    __shared__ float shared_sumval[WARP_SIZE];
+    int num_warps = blockDim.x / WARP_SIZE;
+    int warp_id = threadIdx.x / WARP_SIZE;
+    int lane_id = threadIdx.x % WARP_SIZE;
 
     // reduce maxval within each warp
     float warp_maxval = cg::reduce(warp, thread_maxval, cg::greater<float>{});
@@ -317,11 +321,11 @@ __device__ SoftmaxParams prepare_softmax_blockwide_nofloat4(cg::thread_block_til
     // two reductions of up to 1024 threads:
     // 1) inside warp (shuffle), 2) cross-warp (shared memory), 3) inside warp (shuffle)
     // this results in much cleaner assembly than a multi-warp cg::reduce
-    __shared__ float shared_maxval[32];
-    __shared__ float shared_sumval[32];
-    int num_warps = blockDim.x / 32;
-    int warp_id = threadIdx.x / 32;
-    int lane_id = threadIdx.x % 32;
+    __shared__ float shared_maxval[WARP_SIZE];
+    __shared__ float shared_sumval[WARP_SIZE];
+    int num_warps = blockDim.x / WARP_SIZE;
+    int warp_id = threadIdx.x / WARP_SIZE;
+    int lane_id = threadIdx.x % WARP_SIZE;
 
     // reduce maxval within each warp
     float warp_maxval = cg::reduce(warp, thread_maxval, cg::greater<float>{});
@@ -409,11 +413,11 @@ __device__ SoftmaxParams prepare_softmax_blockwide2(int idx, const floatX* inp,
     // two reductions of up to 1024 threads:
     // 1) inside warp (shuffle), 2) cross-warp (shared memory), 3) inside warp (shuffle)
     // this results in much cleaner assembly than a multi-warp cg::reduce
-    __shared__ float shared_maxval[32];
-    __shared__ float shared_sumval[32];
-    int num_warps = blockDim.x / 32;
-    int warp_id = threadIdx.x / 32;
-    int lane_id = threadIdx.x % 32;
+    __shared__ float shared_maxval[WARP_SIZE];
+    __shared__ float shared_sumval[WARP_SIZE];
+    int num_warps = blockDim.x / WARP_SIZE;
+    int warp_id = threadIdx.x / WARP_SIZE;
+    int lane_id = threadIdx.x % WARP_SIZE;
 
     // reduce maxval within each warp
     float warp_maxval = warpReduceMax(thread_maxval);
@@ -495,10 +499,10 @@ template<reduction_func_t warp_reduction>
 __device__ float blockReduce(float val, bool final_sync=false, float out_of_bounds=0.0f) {
     // two reductions of up to 1024 threads:
     // 1) inside warp (shuffle), 2) cross-warp (shared memory), 3) inside warp (shuffle)
-    __shared__ float shared_val[32];
-    const int lane_id = threadIdx.x % 32;
-    const int warp_id = threadIdx.x / 32;
-    const int num_warps = blockDim.x / 32;
+    __shared__ float shared_val[WARP_SIZE];
+    const int lane_id = threadIdx.x % WARP_SIZE;
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int num_warps = blockDim.x / WARP_SIZE;
 
     float warp_val = warp_reduction(val);
     if (lane_id == 0) { shared_val[warp_id] = warp_val; }
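The hunk cuts off before the cross-warp step completes. For readers following the 1)-2)-3) scheme in the comment, here is a reconstructed sketch of the whole pattern; everything past the lines shown above is an assumption about the file, not a quotation.

// sketch of the complete two-level reduction: warp shuffle -> shared memory -> warp shuffle
using reduction_func_t = float (*)(float);

__device__ float warpReduceSum(float val) {
    // butterfly reduction across the 32 lanes of a warp
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        val += __shfl_xor_sync(0xFFFFFFFFu, val, offset);
    }
    return val;
}

template<reduction_func_t warp_reduction>
__device__ float blockReduceSketch(float val, bool final_sync=false, float out_of_bounds=0.0f) {
    __shared__ float shared_val[WARP_SIZE];               // one slot per warp (max 32 warps/block)
    const int lane_id = threadIdx.x % WARP_SIZE;
    const int warp_id = threadIdx.x / WARP_SIZE;
    const int num_warps = blockDim.x / WARP_SIZE;

    float warp_val = warp_reduction(val);                 // 1) reduce inside each warp
    if (lane_id == 0) { shared_val[warp_id] = warp_val; } // 2) publish one value per warp
    __syncthreads();
    warp_val = (lane_id < num_warps) ? shared_val[lane_id] : out_of_bounds;
    float block_val = warp_reduction(warp_val);           // 3) reduce per-warp values in-warp
    if (final_sync) { __syncthreads(); }                  // optional: protect shared_val reuse
    return block_val;
}

// usage: float total = blockReduceSketch<warpReduceSum>(thread_val);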
@@ -630,7 +634,7 @@ void fused_classifier1(float* dlogits, float* losses,
     const int N = B * T; // total number of rows in the input
     // how many rows of the input can each block of threads process?
     // e.g. in block_size=128, 4 rows get handled by 4 warps (of 32 threads each)
-    const int rows_per_block = block_size / 32;
+    const int rows_per_block = block_size / WARP_SIZE;
     const int grid_size = N / rows_per_block; // total number of blocks needed
     fused_classifier_kernel1<<<grid_size, block_size>>>(dlogits, losses, logits, dlosses, targets, B, T, V, P);
     cudaCheck(cudaGetLastError());
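To make the grid sizing concrete, a worked example with assumed sizes (the B, T, and block_size values are illustrative, not from the PR):

// assumed: B = 8, T = 1024, block_size = 128
int N = 8 * 1024;                        // 8192 rows, one per (batch, time) position
int rows_per_block = 128 / WARP_SIZE;    // 4 warps per block -> 4 rows per block
int grid_size = N / rows_per_block;      // 8192 / 4 = 2048 blocks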
14 changes: 9 additions & 5 deletions dev/cuda/layernorm_forward.cu
@@ -29,6 +29,10 @@ version 5 allocates blocks per row instead of warps per row, same alg as 4 othe
 #include <cooperative_groups/reduce.h>
 #include "common.h"
 
+
+// warp size set as a compile-time constant; this allows the compiler to optimize
+constexpr int WARP_SIZE = 32;
+
 // ----------------------------------------------------------------------------
 // CPU code reference

@@ -285,11 +289,11 @@ __global__ void layernorm_forward_kernel5(float* __restrict__ out, float* __rest
     namespace cg = cooperative_groups;
     cg::thread_block block = cg::this_thread_block();
     cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
-    __shared__ float shared_sum[32]; // block_size max is 1024 = 32 * 32 warps
-    __shared__ float shared_sum2[32]; // warps will be writing into shared memory after warp-reduce
-    int num_warps = blockDim.x / 32;
-    int warp_id = threadIdx.x / 32;
-    int lane_id = threadIdx.x % 32;
+    __shared__ float shared_sum[WARP_SIZE]; // block_size max is 1024 = 32 * 32 warps
+    __shared__ float shared_sum2[WARP_SIZE]; // warps will be writing into shared memory after warp-reduce
+    int num_warps = blockDim.x / WARP_SIZE;
+    int warp_id = threadIdx.x / WARP_SIZE;
+    int lane_id = threadIdx.x % WARP_SIZE;
     int idx = blockIdx.x; // simply one block per row
     // the row of input that this group of threads is responsible for
     const float* x = inp + idx * C;
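For orientation, a sketch of how a per-row partial sum typically flows through shared_sum in a kernel shaped like this one, using the cooperative-groups primitives already included above; the helper name and exact flow are assumptions, not the file's verbatim code. Presumably kernel5 runs this pattern twice per row, for the sum and the sum of squares, and derives the mean and rstd from the two block totals.

// sketch: block-wide sum via warp shuffle -> shared memory -> warp shuffle,
// relying on the cooperative_groups headers included in the hunk above
__device__ float block_row_sum(cg::thread_block block, float* shared, float partial) {
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    float warp_sum = cg::reduce(warp, partial, cg::plus<float>{});   // reduce within each warp
    if (warp.thread_rank() == 0) { shared[warp.meta_group_rank()] = warp_sum; }
    block.sync();                                                    // publish per-warp sums
    int num_warps = blockDim.x / WARP_SIZE;
    float v = (warp.thread_rank() < num_warps) ? shared[warp.thread_rank()] : 0.0f;
    return cg::reduce(warp, v, cg::plus<float>{});                   // same total in every warp
}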