From da987fa1ac37d94f144bb2eb7a80c32a152922d4 Mon Sep 17 00:00:00 2001 From: ademeure Date: Sun, 21 Apr 2024 19:38:09 +0100 Subject: [PATCH 1/6] Initial working FP16 cuDNN Forward Attention (+ FP16 port of kernel 4) --- dev/cuda/attention_forward.cu | 376 ++++++++++++++++++++++++++++++++-- 1 file changed, 357 insertions(+), 19 deletions(-) diff --git a/dev/cuda/attention_forward.cu b/dev/cuda/attention_forward.cu index 695f29981..ec3b2cb64 100644 --- a/dev/cuda/attention_forward.cu +++ b/dev/cuda/attention_forward.cu @@ -1,8 +1,12 @@ /* Kernels for attention forward pass. +If you do not have CUDNN, you can remove ENABLE_CUDNN to run the other kernels +You need cuDNN from: https://developer.nvidia.com/cudnn +And the cuDNN front-end from: https://github.com/NVIDIA/cudnn-frontend/tree/main + Compile example: -nvcc -O3 --use_fast_math attention_forward.cu -o attention_forward -lcublas +nvcc -I/path/to/cudnn-frontend/include -O3 --use_fast_math -lcublas -lcudnn attention_forward.cu -o attention_forward version 1 is naive port from CPU code to kernel, parallelize over batch, time, heads only ./attention_forward 1 @@ -23,8 +27,20 @@ this turns out to be ~20X faster than (1) nice version 4 is a further optimized kernel that fuses the scale operation, uses a directly autoregressive softmax, and uses the online softmax algorithm. ./attention_forward 4 -*/ +version 5 is a FP16 version of kernel 4 +./attention_forward 5 + +version 6 is kernel 5 skipping (un)permute (unrealistic but useful comparison point) + +version 10 is using cuDNN Flash Attention using FP16 or BF16, see: +https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md +./attention_forward 10 + +version 11 is kernel 10 skipping FP16/FP32 conversions (requires fully FP16 network) +./attention_forward 11 +*/ +#define ENABLE_CUDNN #include #include #include @@ -35,11 +51,24 @@ uses a directly autoregressive softmax, and uses the online softmax algorithm. #include #include "common.h" +typedef half lowp_float; // or __nv_bfloat16 + +#ifdef ENABLE_CUDNN +#include +namespace fe = cudnn_frontend; +#define CUDNN_16BIT fe::DataType_t::HALF // or BFLOAT16 (cuDNN kernels only) + +static cudnnHandle_t cudnn_handle; +static size_t cudnn_workspace_size = 32 * 1024 * 1024; +static void* cudnn_workspace = NULL; + +#define checkCudaErr(err) assert((int)err == 0); +#define checkCudnnErr(err) assert((int)err == 0); +#endif // ENABLE_CUDNN // ---------------------------------------------------------------------------- // CUDA setup - static cublasHandle_t cublas_handle; - +static bool first_run_validation = true; // always run e.g. 
permute on 1st run // ---------------------------------------------------------------------------- // CPU code reference @@ -862,18 +891,303 @@ void attention_forward4(float* out, float* vaccum, float* qkvr, float* preatt, f unpermute_kernel<<>>(vaccum, out, B, T, NH, HS); } -void attention_forward5(float* out, float* preatt, float* att, + +__global__ void softmax_forward_kernel5_lowp(lowp_float* out, float inv_temperature, + const lowp_float* inp, int N, int T) { + // inp, out shape: (N, T, T), where N = B * NH + // fuses the multiplication by scale inside attention + // directly autoregressive, so we only compute the lower triangular part + // uses the online softmax algorithm + assert(T % 4 == 0); + namespace cg = cooperative_groups; + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); + int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank(); + if(idx >= N * T) { + return; + } + int own_pos = idx % T; + int pos_by_4 = own_pos / 4; + + // one row of inp, i.e. inp[idx, :] of shape (T,) + const lowp_float* x = inp + idx * T; + + // not INF, so we don't get NaNs accidentally when subtracting two values. + float maxval = -FLT_MAX; + float sumval = 0.0f; + + // Same thing but without float4, one at a time + for (int i = warp.thread_rank(); i < pos_by_4; i += warp.size()) { + float old_maxval = maxval; + for(int k = 0; k < 4; ++k) { + maxval = fmaxf(maxval, (float)x[4*i + k]); + } + sumval *= expf(inv_temperature * (old_maxval - maxval)); + for(int k = 0; k < 4; ++k) { + sumval += expf(inv_temperature * ((float)x[4*i + k] - maxval)); + } + } + + if(4*pos_by_4 + warp.thread_rank() <= own_pos) { + float old_maxval = maxval; + maxval = fmaxf(maxval, x[4*pos_by_4 + warp.thread_rank()]); + sumval *= expf(inv_temperature * (old_maxval - maxval)); + sumval += expf(inv_temperature * ((float)x[4*pos_by_4 + warp.thread_rank()] - maxval)); + } + + float global_maxval = cg::reduce(warp, maxval, cg::greater{}); + sumval *= expf(inv_temperature * (maxval - global_maxval)); + + float sum = cg::reduce(warp, sumval, cg::plus{}); + float norm = 1.f / sum; + + // divide the whole row by the sum + for (int i = warp.thread_rank(); i <= own_pos; i += warp.size()) { + // recalculation is faster than doing the round-trip through memory. 
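        // (each thread re-derives exp(inv_temperature * (x[i] - global_maxval)) for its
        //  columns and writes the normalized probability; __ldcs/__stcs below are
        //  streaming load/store hints, since this kernel never touches these values again)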
+ float ev = expf(inv_temperature * ((float)__ldcs(x + i) - global_maxval)); + __stcs(out + idx * T + i, (lowp_float)(ev * norm)); + } +} + +__global__ void permute_kernel_lowp(lowp_float* q, lowp_float* k, lowp_float* v, + const float* inp, + int B, int N, int NH, int d) { + // okay so now, this kernel wants Q,K,V to all be of shape (B, NH, N, d) + // but instead, we have a single tensor QKV (inp) of shape (B, N, 3, NH, d) + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + // Q[b][nh_][n][d_] = inp[b][n][0][nh_][d_] + if (idx < B * NH * N * d) { + int b = idx / (NH * N * d); + int rest = idx % (NH * N * d); + int nh_ = rest / (N * d); + rest = rest % (N * d); + int n = rest / d; + int d_ = rest % d; + + int inp_idx = \ + (b * N * 3 * NH * d) + + (n * 3 * NH * d) + + (0 * NH * d) + + (nh_ * d) + + d_; + + q[idx] = (lowp_float)inp[inp_idx]; + k[idx] = (lowp_float)inp[inp_idx + NH * d]; + v[idx] = (lowp_float)inp[inp_idx + 2 * (NH * d)]; + } +} + +__global__ void unpermute_kernel_lowp(const lowp_float* inp, float *out, int B, int N, int NH, int d) { + // out has shape (B, nh, N, d) but we need to unpermute it to (B, N, nh, d) + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + // out[b][n][nh_][d_] <- inp[b][nh_][n][d_] + if (idx < B * NH * N * d) { + int b = idx / (NH * N * d); + int rest = idx % (NH * N * d); + int nh_ = rest / (N * d); + rest = rest % (N * d); + int n = rest / d; + int d_ = rest % d; + + int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_; + out[other_idx] = (float)inp[idx]; + } +} + +void attention_forward5(float* out, lowp_float* vaccum, lowp_float* qkvr, lowp_float* preatt, lowp_float* att, const float* inp, int B, int T, int C, int NH, - const int block_size) { - // attention calculation - int x_blocks = ceil_div(T, block_size / 32); - attention_forward_fused1<<>>(out, preatt, att, inp, B, T, C, NH); + const int block_size, bool skip_permute=false) { + // FP16 version of kernel 4 (with permute/unpermute doing FP32<->FP16) + // That permute can be skipped on perf runs to analyse its performance impact + // inp is (B, T, 3C) QKV + // preatt, att are (B, NH, T, T) + // output is (B, T, C) + + // permute and separate inp from (B, T, 3, NH, HS) to 3X (B, NH, T, HS) + int HS = C / NH; // head size + lowp_float *q = qkvr + 0 * B * T * C; + lowp_float *k = qkvr + 1 * B * T * C; + lowp_float* v = qkvr + 2 * B * T * C; + + int total_threads = B * NH * T * HS; + int num_blocks = ceil_div(total_threads, block_size); + if (!skip_permute || first_run_validation) { + permute_kernel_lowp<<>>(q, k, v, inp, B, T, NH, HS); + } + + // batched matrix multiply with cuBLAS + const lowp_float alpha = (lowp_float)1.0f; + const lowp_float beta = (lowp_float)0.0f; + cublasCheck(cublasHgemmStridedBatched(cublas_handle, + CUBLAS_OP_T, CUBLAS_OP_N, + T, T, HS, + &alpha, + k, HS, T * HS, + q, HS, T * HS, + &beta, + preatt, T, T * T, + B * NH)); + + // multiply all elements of preatt elementwise by scale + lowp_float scale = 1.0 / sqrtf(HS); + int softmax_block_size = 256; + int grid_size = ceil_div(B * NH * T * 32, softmax_block_size); + softmax_forward_kernel5_lowp<<>>(att, scale, preatt, B * NH, T); + + // new approach: first cuBLAS another batched matmul + // y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs) + cublasCheck(cublasHgemmStridedBatched(cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + HS, T, T, + &alpha, + v, HS, T * HS, + att, T, T * T, + &beta, + vaccum, HS, T * HS, + B * NH)); + + // now unpermute + // y = y.transpose(1, 2).contiguous().view(B, T, C) # 
re-assemble all head outputs side by side + num_blocks = ceil_div(B * T * C, block_size); + if(!skip_permute || first_run_validation) { + unpermute_kernel_lowp<<>>((lowp_float*)vaccum, out, B, T, NH, HS); + } } +#ifdef ENABLE_CUDNN +using graph_and_tensors = std::tuple, + std::shared_ptr, // Q, + std::shared_ptr, // K, + std::shared_ptr, // V, + std::shared_ptr, // Attn_scale, + std::shared_ptr, // O + std::shared_ptr>; // Stats +using cache_type = std::unordered_map; + +template +auto lookup_cache_or_build_graph(Args... args) { + static cache_type user_maintained_cache; + auto [b, h, s_qkv, d, is_inference] = std::make_tuple(args...); + + auto graph = std::make_shared(); + graph->set_io_data_type(CUDNN_16BIT) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + // (B, N, 3, NH, d) + auto Q = graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim({b, h, s_qkv, d}) + .set_stride({3 * h * d * s_qkv, d, 3 * h * d, 1})); + auto K = graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim({b, h, s_qkv, d}) + .set_stride({3 * h * d * s_qkv, d, 3 * h * d, 1})); + auto V = graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim({b, h, s_qkv, d}) + .set_stride({3 * h * d * s_qkv, d, 3 * h * d, 1})); + auto attn_scale = graph->tensor(fe::graph::Tensor_attributes() + .set_name("attn_scale") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(fe::DataType_t::FLOAT)); + + auto sdpa_options = fe::graph::SDPA_attributes().set_name("flash_attention"); + sdpa_options.set_is_inference(is_inference); + sdpa_options.set_attn_scale(attn_scale); + sdpa_options.set_causal_mask(true); + + auto [O, stats] = graph->sdpa(Q, K, V, sdpa_options); + + O->set_output(true).set_dim({b, h, s_qkv, d}).set_stride({h * d * s_qkv, d, h * d, 1}); + assert(stats == nullptr || is_inference == false); + if (!is_inference) { + stats->set_output(true).set_data_type(fe::DataType_t::FLOAT); + } + + assert(graph->validate().is_good()); + auto key = graph->key(); + auto it = user_maintained_cache.find(key); + if (it != user_maintained_cache.end()) { + return it->second; + } + + assert(graph->build_operation_graph(cudnn_handle).is_good()); + auto plans = graph->create_execution_plans({fe::HeurMode_t::A}); + assert(graph->check_support(cudnn_handle).is_good()); + assert(graph->build_plans(cudnn_handle).is_good()); + + auto tuple = std::make_tuple(graph, Q, K, V, attn_scale, O, stats); + user_maintained_cache.insert({key, tuple}); + return tuple; +} + +__global__ void fp32_to_lowp_kernel(lowp_float* out, const float* inp) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + out[idx] = (lowp_float)inp[idx]; +} + +__global__ void lowp_to_fp32_kernel(const lowp_float* inp, float *out) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + out[idx] = (float)inp[idx]; +} + +void attention_forward10(float* out, // output: (B, T, NH, HS) + float* stats, // for use in backward pass: (B, NH, T) + lowp_float* vaccum, lowp_float* qkvr, + const float* inp, // input: (B, T, 3, NH, HS) QKV + int B, int T, int C, int NH, + bool skip_conversion=false) { + int64_t HS = C / NH; // number of features per head + bool is_inference = stats != NULL; + float attn_scale_cpu = 1.0 / sqrtf(HS); + + const int block_size = 64; + int total_threads = B * T * C * 3; + assert(total_threads % block_size == 0); + int num_blocks = total_threads / block_size; + + if (!skip_conversion || first_run_validation) { + 
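        // (one thread per element of the (B, T, 3C) QKV input, launched as
        //  <<<num_blocks, block_size>>>; the conversion kernel has no bounds check,
        //  which is why total_threads must divide evenly by block_size above)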
fp32_to_lowp_kernel<<>>(qkvr, inp); + } + + auto [graph, Q, K, V, attn_scale, O, softmax_stats] = + lookup_cache_or_build_graph(B, NH, T, HS, is_inference); + + //// Build variant pack + void* devPtrQ = qkvr; + void* devPtrK = (qkvr + NH * HS); + void* devPtrV = (qkvr + 2 * NH * HS); + void* devPtrO = (void*)vaccum; + + std::unordered_map, void*> variant_pack = { + {Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {attn_scale, &attn_scale_cpu}, {O, devPtrO}}; + if (is_inference == false) { + variant_pack[softmax_stats] = (void*)stats; + } + + assert(graph->get_workspace_size() <= cudnn_workspace_size); + assert(graph->execute(cudnn_handle, variant_pack, cudnn_workspace).is_good()); + + total_threads = B * T * C; + assert(total_threads % block_size == 0); + num_blocks = total_threads / block_size; + if (!skip_conversion || first_run_validation) { + lowp_to_fp32_kernel<<>>(vaccum, out); + } +} + +#endif // ENABLE_CUDNN + // kernel version dispatch void attention_forward(int kernel_num, - float* out, float* vaccum, float* qkvr, float* preatt, float* att, + float* out, float* stats, float* vaccum, + float* qkvr, float* preatt, float* att, const float* inp, int B, int T, int C, int NH, const int block_size) { @@ -891,8 +1205,25 @@ void attention_forward(int kernel_num, attention_forward4(out, vaccum, qkvr, preatt, att, inp, B, T, C, NH, block_size); break; case 5: - attention_forward5(out, preatt, att, inp, B, T, C, NH, block_size); + attention_forward5(out, (lowp_float*)vaccum, (lowp_float*)qkvr, + (lowp_float*)preatt, (lowp_float*)att, + inp, B, T, C, NH, block_size, false); + break; + case 6: // skip permutes for perf passes (to analyse perf as if in/out were truly 16-bit) + attention_forward5(out, (lowp_float*)vaccum, (lowp_float*)qkvr, + (lowp_float*)preatt, (lowp_float*)att, + inp, B, T, C, NH, block_size, true); + break; + #ifdef ENABLE_CUDNN + case 10: + attention_forward10(out, stats, (lowp_float*)vaccum, (lowp_float*)qkvr, + inp, B, T, C, NH, false); + break; + case 11: // skip permutes for perf passes (to analyse perf as if in/out were truly 16-bit) + attention_forward10(out, stats, (lowp_float*)vaccum, (lowp_float*)qkvr, + inp, B, T, C, NH, true); break; + #endif default: printf("Invalid kernel number\n"); exit(1); @@ -911,6 +1242,10 @@ int main(int argc, char **argv) { int deviceIdx = 0; cudaCheck(cudaSetDevice(deviceIdx)); cublasCreate(&cublas_handle); + #ifdef ENABLE_CUDNN + checkCudnnErr(cudnnCreate(&cudnn_handle)); + cudaCheck(cudaMalloc(&cudnn_workspace, cudnn_workspace_size)); + #endif // create host memory of random numbers float* out = (float*)malloc(B * T * C * sizeof(float)); @@ -920,12 +1255,14 @@ int main(int argc, char **argv) { // move to GPU float* d_out; + float* d_stats; // for cuDNN float* d_vaccum; float* d_qkvr; float* d_preatt; float* d_att; float* d_inp; cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(float))); + cudaCheck(cudaMalloc(&d_stats, B * NH * T * sizeof(float))); cudaCheck(cudaMalloc(&d_vaccum, B * T * C * sizeof(float))); cudaCheck(cudaMalloc(&d_qkvr, B * T * 3 * C * sizeof(float))); cudaCheck(cudaMalloc(&d_preatt, B * NH * T * T * sizeof(float))); @@ -946,23 +1283,24 @@ int main(int argc, char **argv) { for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { int block_size = block_sizes[j]; printf("Checking block size %d.\n", block_size); - attention_forward(kernel_num, d_out, d_vaccum, d_qkvr, d_preatt, d_att, d_inp, B, T, C, NH, block_size); + attention_forward(kernel_num, d_out, d_stats, d_vaccum, d_qkvr, d_preatt, d_att, d_inp, B, T, C, NH, 
block_size); // all kernels should produce the correct output out - validate_result(d_out, out, "out", B * T * C, 1e-4f); + validate_result(d_out, out, "out", B * T * C, 1e-3f); // but as for preatt and att, things get a bit more complicated: - if (kernel_num != 2) { + if (kernel_num != 2 && kernel_num < 5) { // kernel 2 (knowingly) fails att/preatt because it uses a different algorithm // that estimates the softmax online and never materializes preatt/att - validate_result(d_att, att, "att", B * NH * T * T, 1e-4f); + validate_result(d_att, att, "att", B * NH * T * T, 1e-3f); } - if (kernel_num != 2 && kernel_num != 4 && kernel_num != 5) { + if (kernel_num != 2 && kernel_num < 4) { // kernel 4 (knowingly) fails preatt because it fuses the scale normalization // into the softmax, so preatt is off by 1.0f / sqrt(HS) // but att and out (checked below) should match. - validate_result(d_preatt, preatt, "preatt", B * NH * T * T, 1e-4f); + validate_result(d_preatt, preatt, "preatt", B * NH * T * T, 1e-3f); } } printf("All results match. Starting benchmarks.\n\n"); + first_run_validation = false; // benchmark speed of the kernel for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { @@ -970,8 +1308,8 @@ int main(int argc, char **argv) { int repeat_times = 100; float elapsed_time = benchmark_kernel(repeat_times, attention_forward, - kernel_num, d_out, d_vaccum, d_qkvr, d_preatt, d_att, d_inp, - B, T, C, NH, block_size); + kernel_num, d_out, d_stats, d_vaccum, d_qkvr, d_preatt, d_att, + d_inp, B, T, C, NH, block_size); printf("block_size %4d | time %f ms\n", block_size, elapsed_time); } From 8a93ecf4470cacfbcd4dc02739fc0369367fee7e Mon Sep 17 00:00:00 2001 From: ademeure Date: Sun, 21 Apr 2024 21:30:57 +0100 Subject: [PATCH 2/6] WIP - cuBLAS Ex() not working as expected... 
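Context for this WIP commit (not part of the diff): the Ex() path being chased here lands
in PATCH 4/6 below. The detail that matters is that alpha/beta must match the computeType
rather than the A/B/C data type: FP32 scalars for CUBLAS_COMPUTE_32F, 16-bit scalars only
for CUBLAS_COMPUTE_16F, and cuBLAS silently produces garbage rather than an error when
they are mixed up. A minimal sketch of the eventual Q @ K^T call, reusing the names and
strides from these patches:

    // sketch: preatt = q @ k^T for all B*NH heads, 16-bit A/B/C with FP32 accumulation
    const float alpha = 1.0f, beta = 0.0f;  // FP32 scalars because computeType is 32F
    cublasCheck(cublasGemmStridedBatchedEx(cublas_handle,
                                           CUBLAS_OP_T, CUBLAS_OP_N,
                                           T, T, HS,
                                           &alpha,
                                           k, CUBLAS_LOWP, HS, T * HS,   // A = K, one head per batch entry
                                           q, CUBLAS_LOWP, HS, T * HS,   // B = Q
                                           &beta,
                                           preatt, CUBLAS_LOWP, T, T * T,
                                           B * NH,
                                           CUBLAS_COMPUTE_32F,
                                           CUBLAS_GEMM_DEFAULT));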
--- dev/cuda/attention_forward.cu | 106 +++++++++++++++++++++++++++++++++- 1 file changed, 104 insertions(+), 2 deletions(-) diff --git a/dev/cuda/attention_forward.cu b/dev/cuda/attention_forward.cu index ec3b2cb64..6db2bbb89 100644 --- a/dev/cuda/attention_forward.cu +++ b/dev/cuda/attention_forward.cu @@ -40,18 +40,22 @@ https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md version 11 is kernel 10 skipping FP16/FP32 conversions (requires fully FP16 network) ./attention_forward 11 */ -#define ENABLE_CUDNN +//#define ENABLE_CUDNN #include #include #include #include #include #include +#include #include #include #include "common.h" typedef half lowp_float; // or __nv_bfloat16 +#define CUBLAS_LOWP CUDA_R_16F +#define CUBLAS_LOWP_COMPUTE CUBLAS_COMPUTE_32F +//#define CUBLAS_LOWP_COMPUTE CUBLAS_COMPUTE_32F // 32F compute is as fast as 16F on A100/H100 #ifdef ENABLE_CUDNN #include @@ -857,6 +861,21 @@ void attention_forward4(float* out, float* vaccum, float* qkvr, float* preatt, f // batched matrix multiply with cuBLAS const float alpha = 1.0f; const float beta = 0.0f; + + /*cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, + CUBLAS_OP_T, CUBLAS_OP_N, + T, T, HS, + &alpha, + k, CUDA_R_32F, HS, T * HS, + q, CUDA_R_32F, HS, T * HS, + &beta, + preatt, CUDA_R_32F, T, T * T, + B * NH, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT));*/ + + //cudaCheck(cudaMemset(preatt, 0, B * NH * T * T * sizeof(float))); + cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, T, T, HS, @@ -875,6 +894,18 @@ void attention_forward4(float* out, float* vaccum, float* qkvr, float* preatt, f // new approach: first cuBLAS another batched matmul // y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs) +/*cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + HS, T, T, + &alpha, + v, CUDA_R_32F, HS, T * HS, + att, CUDA_R_32F, T, T * T, + &beta, + vaccum, CUDA_R_32F, HS, T * HS, + B * NH, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT));*/ + cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, HS, T, T, @@ -1018,9 +1049,65 @@ void attention_forward5(float* out, lowp_float* vaccum, lowp_float* qkvr, lowp_f permute_kernel_lowp<<>>(q, k, v, inp, B, T, NH, HS); } +/* +cublasStatus_t cublasGemmStridedBatchedEx(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const void *alpha, + const void *A, + cudaDataType_t Atype, + int lda, + long long int strideA, + const void *B, + cudaDataType_t Btype, + int ldb, + long long int strideB, + const void *beta, + void *C, + cudaDataType_t Ctype, + int ldc, + long long int strideC, + int batchCount, + cublasComputeType_t computeType, + cublasGemmAlgo_t algo) + +cublasStatus_t cublasHgemmStridedBatched(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, int n, int k, + const __half *alpha, + const __half *A, int lda, + long long int strideA, + const __half *B, int ldb, + long long int strideB, + const __half *beta, + __half *C, int ldc, + long long int strideC, + int batchCount) +*/ + // batched matrix multiply with cuBLAS const lowp_float alpha = (lowp_float)1.0f; const lowp_float beta = (lowp_float)0.0f; + + /*cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, + CUBLAS_OP_T, CUBLAS_OP_N, + T, T, HS, + &alpha, + k, CUBLAS_LOWP, HS, T * HS, + q, CUBLAS_LOWP, HS, T * HS, + &beta, + preatt, CUBLAS_LOWP, T, T * T, + B * NH, + CUBLAS_LOWP_COMPUTE, + CUBLAS_GEMM_DEFAULT));*/ + + // memset 
preatt - things don't break as much as they should... + //cudaCheck(cudaMemset(preatt, 0, B * NH * T * T * sizeof(lowp_float))); + cublasCheck(cublasHgemmStridedBatched(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, T, T, HS, @@ -1031,6 +1118,7 @@ void attention_forward5(float* out, lowp_float* vaccum, lowp_float* qkvr, lowp_f preatt, T, T * T, B * NH)); + // multiply all elements of preatt elementwise by scale lowp_float scale = 1.0 / sqrtf(HS); int softmax_block_size = 256; @@ -1039,6 +1127,19 @@ void attention_forward5(float* out, lowp_float* vaccum, lowp_float* qkvr, lowp_f // new approach: first cuBLAS another batched matmul // y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs) + + /*cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + HS, T, T, + &alpha, + v, CUBLAS_LOWP, HS, T * HS, + att, CUBLAS_LOWP, T, T * T, + &beta, + vaccum, CUBLAS_LOWP, HS, T * HS, + B * NH, + CUBLAS_LOWP_COMPUTE, + CUBLAS_GEMM_DEFAULT));*/ + cublasCheck(cublasHgemmStridedBatched(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, HS, T, T, @@ -1053,7 +1154,7 @@ void attention_forward5(float* out, lowp_float* vaccum, lowp_float* qkvr, lowp_f // y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side num_blocks = ceil_div(B * T * C, block_size); if(!skip_permute || first_run_validation) { - unpermute_kernel_lowp<<>>((lowp_float*)vaccum, out, B, T, NH, HS); + unpermute_kernel_lowp<<>>(vaccum, out, B, T, NH, HS); } } @@ -1251,6 +1352,7 @@ int main(int argc, char **argv) { float* out = (float*)malloc(B * T * C * sizeof(float)); float* preatt = (float*)malloc(B * NH * T * T * sizeof(float)); float* att = (float*)malloc(B * NH * T * T * sizeof(float)); + //float* inp = make_random_float(B * T * 3 * C, 10.0f); float* inp = make_random_float(B * T * 3 * C); // move to GPU From 6ed6e99f33c38e49aee0c9276256aee6596b4a0c Mon Sep 17 00:00:00 2001 From: ademeure Date: Mon, 22 Apr 2024 00:07:46 +0100 Subject: [PATCH 3/6] Fix FP16/BF16 kernel version issues --- train_gpt2.cu | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/train_gpt2.cu b/train_gpt2.cu index c99bd2e09..196dadebf 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -849,6 +849,15 @@ void matmul_forward_cublaslt(float* out, cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogueBias, sizeof(epilogueBias))); cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias))); + //cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, OC, B*T, C, &alpha, weight, C, inp, C, &beta, out, OC)); + //m=OC + //n=B*T + //k=C + + //cublasCheck(cublasLtMatrixLayoutCreate(&weightLayout, CUDA_R_32F, m, k, m)); + //cublasCheck(cublasLtMatrixLayoutCreate(&inputLayout, CUDA_R_32F, k, n, k)); + //cublasCheck(cublasLtMatrixLayoutCreate(&outputLayout, CUDA_R_32F, m, n, m)); + // define matrix layouts cublasCheck(cublasLtMatrixLayoutCreate(&weightLayout, CUDA_R_32F, C, OC, C)); cublasCheck(cublasLtMatrixLayoutCreate(&inputLayout, CUDA_R_32F, C, B*T, C)); From 929ad2fd0fc652738960698ab7fe72362068c509 Mon Sep 17 00:00:00 2001 From: ademeure Date: Mon, 22 Apr 2024 12:45:55 +0100 Subject: [PATCH 4/6] Use -DENABLE_CUDNN for cuDNN path instead, and enable TF32 for attention_forward.cu by default --- dev/cuda/attention_forward.cu | 148 +++++++++------------------------- 1 file changed, 37 insertions(+), 111 deletions(-) diff --git a/dev/cuda/attention_forward.cu b/dev/cuda/attention_forward.cu index 6db2bbb89..dd4565290 
100644 --- a/dev/cuda/attention_forward.cu +++ b/dev/cuda/attention_forward.cu @@ -6,7 +6,7 @@ You need cuDNN from: https://developer.nvidia.com/cudnn And the cuDNN front-end from: https://github.com/NVIDIA/cudnn-frontend/tree/main Compile example: -nvcc -I/path/to/cudnn-frontend/include -O3 --use_fast_math -lcublas -lcudnn attention_forward.cu -o attention_forward +nvcc -I/path/to/cudnn-frontend/include -DENABLE_CUDNN -O3 --use_fast_math -lcublas -lcudnn attention_forward.cu -o attention_forward version 1 is naive port from CPU code to kernel, parallelize over batch, time, heads only ./attention_forward 1 @@ -40,7 +40,7 @@ https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md version 11 is kernel 10 skipping FP16/FP32 conversions (requires fully FP16 network) ./attention_forward 11 */ -//#define ENABLE_CUDNN +//#define ENABLE_CUDNN // can be enabled via nvcc "-DENABLE_CUDNN" #include #include #include @@ -52,15 +52,15 @@ version 11 is kernel 10 skipping FP16/FP32 conversions (requires fully FP16 netw #include #include "common.h" -typedef half lowp_float; // or __nv_bfloat16 -#define CUBLAS_LOWP CUDA_R_16F +// Class that wraps __nv_fp8_e4m3 and returns __nv_fp8_storage_t on demand +typedef __nv_bfloat16 lowp_float; // or __nv_bfloat16 +#define CUBLAS_LOWP CUDA_R_16BF #define CUBLAS_LOWP_COMPUTE CUBLAS_COMPUTE_32F -//#define CUBLAS_LOWP_COMPUTE CUBLAS_COMPUTE_32F // 32F compute is as fast as 16F on A100/H100 #ifdef ENABLE_CUDNN #include namespace fe = cudnn_frontend; -#define CUDNN_16BIT fe::DataType_t::HALF // or BFLOAT16 (cuDNN kernels only) +#define CUDNN_16BIT fe::DataType_t::BFLOAT16 // or BFLOAT16 (cuDNN kernels only) static cudnnHandle_t cudnn_handle; static size_t cudnn_workspace_size = 32 * 1024 * 1024; @@ -862,20 +862,6 @@ void attention_forward4(float* out, float* vaccum, float* qkvr, float* preatt, f const float alpha = 1.0f; const float beta = 0.0f; - /*cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, - CUBLAS_OP_T, CUBLAS_OP_N, - T, T, HS, - &alpha, - k, CUDA_R_32F, HS, T * HS, - q, CUDA_R_32F, HS, T * HS, - &beta, - preatt, CUDA_R_32F, T, T * T, - B * NH, - CUBLAS_COMPUTE_32F, - CUBLAS_GEMM_DEFAULT));*/ - - //cudaCheck(cudaMemset(preatt, 0, B * NH * T * T * sizeof(float))); - cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, T, T, HS, @@ -894,18 +880,6 @@ void attention_forward4(float* out, float* vaccum, float* qkvr, float* preatt, f // new approach: first cuBLAS another batched matmul // y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs) -/*cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - HS, T, T, - &alpha, - v, CUDA_R_32F, HS, T * HS, - att, CUDA_R_32F, T, T * T, - &beta, - vaccum, CUDA_R_32F, HS, T * HS, - B * NH, - CUBLAS_COMPUTE_32F, - CUBLAS_GEMM_DEFAULT));*/ - cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, HS, T, T, @@ -961,7 +935,7 @@ __global__ void softmax_forward_kernel5_lowp(lowp_float* out, float inv_temperat if(4*pos_by_4 + warp.thread_rank() <= own_pos) { float old_maxval = maxval; - maxval = fmaxf(maxval, x[4*pos_by_4 + warp.thread_rank()]); + maxval = fmaxf(maxval, (float)x[4*pos_by_4 + warp.thread_rank()]); sumval *= expf(inv_temperature * (old_maxval - maxval)); sumval += expf(inv_temperature * ((float)x[4*pos_by_4 + warp.thread_rank()] - maxval)); } @@ -1049,78 +1023,30 @@ void attention_forward5(float* out, lowp_float* vaccum, lowp_float* qkvr, lowp_f permute_kernel_lowp<<>>(q, k, v, inp, B, T, NH, HS); } -/* 
-cublasStatus_t cublasGemmStridedBatchedEx(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const void *alpha, - const void *A, - cudaDataType_t Atype, - int lda, - long long int strideA, - const void *B, - cudaDataType_t Btype, - int ldb, - long long int strideB, - const void *beta, - void *C, - cudaDataType_t Ctype, - int ldc, - long long int strideC, - int batchCount, - cublasComputeType_t computeType, - cublasGemmAlgo_t algo) - -cublasStatus_t cublasHgemmStridedBatched(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, int n, int k, - const __half *alpha, - const __half *A, int lda, - long long int strideA, - const __half *B, int ldb, - long long int strideB, - const __half *beta, - __half *C, int ldc, - long long int strideC, - int batchCount) -*/ + // IMPORTANT: alpha/beta are FP32 for CUBLAS_COMPUTE_32F even if FP16 inputs/outputs + // But need FP16 scale for CUBLAS_COMPUTE_16F (no errors if you get it wrong *sigh*) + const float alpha = 1.0f; + const float beta = 0.0f; + const lowp_float alpha_lowp = (lowp_float)alpha; + const lowp_float beta_lowp = (lowp_float)beta; + void* alpha_ptr = CUBLAS_LOWP_COMPUTE == CUBLAS_COMPUTE_16F ? (void*)&alpha_lowp : (void*)α + void* beta_ptr = CUBLAS_LOWP_COMPUTE == CUBLAS_COMPUTE_16F ? (void*)&beta_lowp : (void*)β // batched matrix multiply with cuBLAS - const lowp_float alpha = (lowp_float)1.0f; - const lowp_float beta = (lowp_float)0.0f; - - /*cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, + cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, T, T, HS, - &alpha, + alpha_ptr, k, CUBLAS_LOWP, HS, T * HS, q, CUBLAS_LOWP, HS, T * HS, - &beta, + beta_ptr, preatt, CUBLAS_LOWP, T, T * T, B * NH, CUBLAS_LOWP_COMPUTE, - CUBLAS_GEMM_DEFAULT));*/ - - // memset preatt - things don't break as much as they should... 
- //cudaCheck(cudaMemset(preatt, 0, B * NH * T * T * sizeof(lowp_float))); - - cublasCheck(cublasHgemmStridedBatched(cublas_handle, - CUBLAS_OP_T, CUBLAS_OP_N, - T, T, HS, - &alpha, - k, HS, T * HS, - q, HS, T * HS, - &beta, - preatt, T, T * T, - B * NH)); - + CUBLAS_GEMM_DEFAULT)); // multiply all elements of preatt elementwise by scale - lowp_float scale = 1.0 / sqrtf(HS); + float scale = 1.0f / sqrtf(HS); int softmax_block_size = 256; int grid_size = ceil_div(B * NH * T * 32, softmax_block_size); softmax_forward_kernel5_lowp<<>>(att, scale, preatt, B * NH, T); @@ -1128,27 +1054,17 @@ cublasStatus_t cublasHgemmStridedBatched(cublasHandle_t handle, // new approach: first cuBLAS another batched matmul // y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs) - /*cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, + cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, HS, T, T, - &alpha, + alpha_ptr, v, CUBLAS_LOWP, HS, T * HS, att, CUBLAS_LOWP, T, T * T, - &beta, + beta_ptr, vaccum, CUBLAS_LOWP, HS, T * HS, B * NH, CUBLAS_LOWP_COMPUTE, - CUBLAS_GEMM_DEFAULT));*/ - - cublasCheck(cublasHgemmStridedBatched(cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - HS, T, T, - &alpha, - v, HS, T * HS, - att, T, T * T, - &beta, - vaccum, HS, T * HS, - B * NH)); + CUBLAS_GEMM_DEFAULT)); // now unpermute // y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side @@ -1342,7 +1258,16 @@ int main(int argc, char **argv) { int deviceIdx = 0; cudaCheck(cudaSetDevice(deviceIdx)); + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, deviceIdx); + + // setup cuBLAS (and cuDNN if needed) cublasCreate(&cublas_handle); + int enable_tf32 = deviceProp.major >= 8 ? 1 : 0; + printf("enable_tf32: %d\n", enable_tf32); + cublasMath_t cublas_math_mode = enable_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH; + cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode)); + #ifdef ENABLE_CUDNN checkCudnnErr(cudnnCreate(&cudnn_handle)); cudaCheck(cudaMalloc(&cudnn_workspace, cudnn_workspace_size)); @@ -1387,18 +1312,19 @@ int main(int argc, char **argv) { printf("Checking block size %d.\n", block_size); attention_forward(kernel_num, d_out, d_stats, d_vaccum, d_qkvr, d_preatt, d_att, d_inp, B, T, C, NH, block_size); // all kernels should produce the correct output out - validate_result(d_out, out, "out", B * T * C, 1e-3f); + // todo - make accuracy threshold dynamic and depend on FP16 vs FP32? + validate_result(d_out, out, "out", B * T * C, 1e-2f); // but as for preatt and att, things get a bit more complicated: if (kernel_num != 2 && kernel_num < 5) { // kernel 2 (knowingly) fails att/preatt because it uses a different algorithm // that estimates the softmax online and never materializes preatt/att - validate_result(d_att, att, "att", B * NH * T * T, 1e-3f); + validate_result(d_att, att, "att", B * NH * T * T, 1e-2f); } if (kernel_num != 2 && kernel_num < 4) { // kernel 4 (knowingly) fails preatt because it fuses the scale normalization // into the softmax, so preatt is off by 1.0f / sqrt(HS) // but att and out (checked below) should match. - validate_result(d_preatt, preatt, "preatt", B * NH * T * T, 1e-3f); + validate_result(d_preatt, preatt, "preatt", B * NH * T * T, 1e-2f); } } printf("All results match. 
Starting benchmarks.\n\n"); From 33a31a7815cb41cf6aebdcc87385bacd4d67a4ce Mon Sep 17 00:00:00 2001 From: ademeure Date: Mon, 22 Apr 2024 13:19:19 +0100 Subject: [PATCH 5/6] extra comments + tiny fix --- dev/cuda/attention_forward.cu | 52 ++++++++++++++++++++++------------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/dev/cuda/attention_forward.cu b/dev/cuda/attention_forward.cu index dd4565290..359be4cc4 100644 --- a/dev/cuda/attention_forward.cu +++ b/dev/cuda/attention_forward.cu @@ -52,28 +52,35 @@ version 11 is kernel 10 skipping FP16/FP32 conversions (requires fully FP16 netw #include #include "common.h" -// Class that wraps __nv_fp8_e4m3 and returns __nv_fp8_storage_t on demand -typedef __nv_bfloat16 lowp_float; // or __nv_bfloat16 -#define CUBLAS_LOWP CUDA_R_16BF +// ---------------------------------------------------------------------------- +// Floating point precision setup +typedef __nv_bfloat16 lowp_float; // half or __nv_bfloat16 (or float) +#define CUBLAS_LOWP CUDA_R_16BF // CUDA_R_16F or CUDA_R_16BF (or CUDA_R_32F) +// CUBLAS_COMPUTE_32F or CUBLAS_COMPUTE_16F (for CUDA_R_16F only, potentially slower?!) #define CUBLAS_LOWP_COMPUTE CUBLAS_COMPUTE_32F +// ---------------------------------------------------------------------------- +// CUDA & cuDNN setup +static cublasHandle_t cublas_handle; +static bool first_run_validation = true; // always run e.g. permute on 1st run + #ifdef ENABLE_CUDNN #include namespace fe = cudnn_frontend; -#define CUDNN_16BIT fe::DataType_t::BFLOAT16 // or BFLOAT16 (cuDNN kernels only) +#if CUBLAS_LOWP == CUDA_R_16BF +#define CUDNN_16BIT fe::DataType_t::BFLOAT16 +#else +#define CUDNN_16BIT fe::DataType_t::HALF +#endif static cudnnHandle_t cudnn_handle; -static size_t cudnn_workspace_size = 32 * 1024 * 1024; +static size_t cudnn_workspace_size = 32 * 1024 * 1024; // TODO is this only for backward? static void* cudnn_workspace = NULL; #define checkCudaErr(err) assert((int)err == 0); #define checkCudnnErr(err) assert((int)err == 0); #endif // ENABLE_CUDNN // ---------------------------------------------------------------------------- -// CUDA setup -static cublasHandle_t cublas_handle; -static bool first_run_validation = true; // always run e.g. 
permute on 1st run -// ---------------------------------------------------------------------------- // CPU code reference void attention_forward_cpu(float* out, float* preatt, float* att, @@ -1024,7 +1031,7 @@ void attention_forward5(float* out, lowp_float* vaccum, lowp_float* qkvr, lowp_f } // IMPORTANT: alpha/beta are FP32 for CUBLAS_COMPUTE_32F even if FP16 inputs/outputs - // But need FP16 scale for CUBLAS_COMPUTE_16F (no errors if you get it wrong *sigh*) + // But need FP16 scale for CUBLAS_COMPUTE_16F (no errors otherwise, just garbage results *sigh*) const float alpha = 1.0f; const float beta = 0.0f; const lowp_float alpha_lowp = (lowp_float)alpha; @@ -1053,7 +1060,6 @@ void attention_forward5(float* out, lowp_float* vaccum, lowp_float* qkvr, lowp_f // new approach: first cuBLAS another batched matmul // y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs) - cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, HS, T, T, @@ -1082,8 +1088,11 @@ using graph_and_tensors = std::tuple, std::shared_ptr, // Attn_scale, std::shared_ptr, // O std::shared_ptr>; // Stats + +// Need a cache because graph->build_operation_graph() is slow but everything else seems fast using cache_type = std::unordered_map; +// Loosely based on cuDNN frontend samples functions and massively simplified template auto lookup_cache_or_build_graph(Args... args) { static cache_type user_maintained_cache; @@ -1094,7 +1103,7 @@ auto lookup_cache_or_build_graph(Args... args) { .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); - // (B, N, 3, NH, d) + // QKV is (B, N, 3, NH, d) which cuDNN can handle directly without an external permute auto Q = graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") .set_dim({b, h, s_qkv, d}) @@ -1121,6 +1130,7 @@ auto lookup_cache_or_build_graph(Args... args) { auto [O, stats] = graph->sdpa(Q, K, V, sdpa_options); + // Output is (B, N, NH, d) BF16/FP16 and stats for backward pass is (B, NH, N) FP32 O->set_output(true).set_dim({b, h, s_qkv, d}).set_stride({h * d * s_qkv, d, h * d, 1}); assert(stats == nullptr || is_inference == false); if (!is_inference) { @@ -1164,11 +1174,11 @@ void attention_forward10(float* out, // output: (B, T, NH, HS) bool is_inference = stats != NULL; float attn_scale_cpu = 1.0 / sqrtf(HS); - const int block_size = 64; + // Optionally convert from FP32 to FP16/BF16 (always on 1st run to get correct results) + const int block_size = 64; // smallest full occupancy block size on modern GPUs int total_threads = B * T * C * 3; assert(total_threads % block_size == 0); int num_blocks = total_threads / block_size; - if (!skip_conversion || first_run_validation) { fp32_to_lowp_kernel<<>>(qkvr, inp); } @@ -1181,16 +1191,17 @@ void attention_forward10(float* out, // output: (B, T, NH, HS) void* devPtrK = (qkvr + NH * HS); void* devPtrV = (qkvr + 2 * NH * HS); void* devPtrO = (void*)vaccum; - std::unordered_map, void*> variant_pack = { {Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {attn_scale, &attn_scale_cpu}, {O, devPtrO}}; if (is_inference == false) { variant_pack[softmax_stats] = (void*)stats; } - assert(graph->get_workspace_size() <= cudnn_workspace_size); + // Execute graph + assert(graph->get_workspace_size() <= cudnn_workspace_size); // TODO - not needed for forward? 
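    // (execute() runs the whole SDPA graph in one call: the variant pack maps each graph
    //  tensor handle, i.e. Q/K/V/attn_scale/O plus Stats when training, to its device
    //  pointer; attn_scale was declared pass-by-value, so it reads the host float
    //  attn_scale_cpu directly instead of a device buffer)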
assert(graph->execute(cudnn_handle, variant_pack, cudnn_workspace).is_good()); + // Optionally convert back from FP16/BF16 to FP32 total_threads = B * T * C; assert(total_threads % block_size == 0); num_blocks = total_threads / block_size; @@ -1304,6 +1315,9 @@ int main(int argc, char **argv) { } printf("Using kernel %d\n", kernel_num); int block_sizes[] = {32, 64, 128, 256, 512}; + + // Lower accuracy requirements for FP16 (1e-4f also too much for TF32 on kernels 3 & 4) + float accuracy_threshold = (kernel_num <= 4) ? 1e-3f : 1e-2f; // first check the correctness of the kernel attention_forward_cpu(out, preatt, att, inp, B, T, C, NH); @@ -1313,18 +1327,18 @@ int main(int argc, char **argv) { attention_forward(kernel_num, d_out, d_stats, d_vaccum, d_qkvr, d_preatt, d_att, d_inp, B, T, C, NH, block_size); // all kernels should produce the correct output out // todo - make accuracy threshold dynamic and depend on FP16 vs FP32? - validate_result(d_out, out, "out", B * T * C, 1e-2f); + validate_result(d_out, out, "out", B * T * C, accuracy_threshold); // but as for preatt and att, things get a bit more complicated: if (kernel_num != 2 && kernel_num < 5) { // kernel 2 (knowingly) fails att/preatt because it uses a different algorithm // that estimates the softmax online and never materializes preatt/att - validate_result(d_att, att, "att", B * NH * T * T, 1e-2f); + validate_result(d_att, att, "att", B * NH * T * T, accuracy_threshold); } if (kernel_num != 2 && kernel_num < 4) { // kernel 4 (knowingly) fails preatt because it fuses the scale normalization // into the softmax, so preatt is off by 1.0f / sqrt(HS) // but att and out (checked below) should match. - validate_result(d_preatt, preatt, "preatt", B * NH * T * T, 1e-2f); + validate_result(d_preatt, preatt, "preatt", B * NH * T * T, accuracy_threshold); } } printf("All results match. Starting benchmarks.\n\n"); From 2c47bd42bad450f347c867cbd9657605ef4d285e Mon Sep 17 00:00:00 2001 From: ademeure Date: Mon, 22 Apr 2024 13:24:55 +0100 Subject: [PATCH 6/6] remove unintentional train_gpu2.cu change --- train_gpt2.cu | 9 --------- 1 file changed, 9 deletions(-) diff --git a/train_gpt2.cu b/train_gpt2.cu index 196dadebf..c99bd2e09 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -849,15 +849,6 @@ void matmul_forward_cublaslt(float* out, cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogueBias, sizeof(epilogueBias))); cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias))); - //cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, OC, B*T, C, &alpha, weight, C, inp, C, &beta, out, OC)); - //m=OC - //n=B*T - //k=C - - //cublasCheck(cublasLtMatrixLayoutCreate(&weightLayout, CUDA_R_32F, m, k, m)); - //cublasCheck(cublasLtMatrixLayoutCreate(&inputLayout, CUDA_R_32F, k, n, k)); - //cublasCheck(cublasLtMatrixLayoutCreate(&outputLayout, CUDA_R_32F, m, n, m)); - // define matrix layouts cublasCheck(cublasLtMatrixLayoutCreate(&weightLayout, CUDA_R_32F, C, OC, C)); cublasCheck(cublasLtMatrixLayoutCreate(&inputLayout, CUDA_R_32F, C, B*T, C));
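
A closing usage sketch (not part of any patch above; the cudnn-frontend include path is a
placeholder). After this series, the precision of the low-precision kernels is controlled
by one small block near the top of dev/cuda/attention_forward.cu, and the cuDNN path is
opt-in at compile time:

    // BF16 is where the series ends up; switch lowp_float to half and CUBLAS_LOWP to
    // CUDA_R_16F to exercise FP16 instead (PATCH 5/6 selects CUDNN_16BIT from CUBLAS_LOWP)
    typedef __nv_bfloat16 lowp_float;               // half or __nv_bfloat16 (or float)
    #define CUBLAS_LOWP CUDA_R_16BF                 // CUDA_R_16F or CUDA_R_16BF (or CUDA_R_32F)
    #define CUBLAS_LOWP_COMPUTE CUBLAS_COMPUTE_32F  // FP32 accumulation for the batched GEMMs

    // build with the cuDNN flash-attention path enabled:
    //   nvcc -I/path/to/cudnn-frontend/include -DENABLE_CUDNN -O3 --use_fast_math \
    //        -lcublas -lcudnn attention_forward.cu -o attention_forward
    // and run, for example:
    //   ./attention_forward 5    # 16-bit port of kernel 4 (cuBLAS batched GEMMs + online softmax)
    //   ./attention_forward 10   # cuDNN flash attention, including the FP32<->16-bit conversions
    //   ./attention_forward 11   # same, skipping the conversions (assumes a fully 16-bit network)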