From e587020b674af048de61b2bfb04584a59f332010 Mon Sep 17 00:00:00 2001
From: sstamenk
Date: Mon, 3 Nov 2025 21:53:43 +0100
Subject: [PATCH] Fix int32 overflow for blocksize quantization

---
 csrc/kernels.hip | 28 ++++++++++++++------------
 csrc/ops.hip     | 52 ++++++++++++++++++++++++------------------
 2 files changed, 41 insertions(+), 39 deletions(-)
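Note on the failure mode this patch addresses: n (the element count) is a
32-bit int throughout these paths, so intermediate expressions such as
(n + tile_size - 1) and gridDim.x * BLOCK_SIZE can exceed INT_MAX once a
tensor holds close to 2^31 elements, and the computed grid size or loop
bound goes negative. A minimal host-side sketch of the arithmetic (the
element count below is illustrative, not taken from the kernels):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int n = 2147483000;    // near INT_MAX, e.g. a ~2^31-element tensor
        const int tile_size = 1024;  // tile size used by the 4-bit dequant path

        // 32-bit arithmetic: n + tile_size - 1 exceeds INT_MAX. Signed
        // overflow is undefined behavior; in practice it wraps negative,
        // so the computed grid size is garbage.
        int grid_bad = (n + tile_size - 1) / tile_size;

        // 64-bit intermediate, as in the patched dequantizeBlockwise.
        int grid_ok = (int)(((int64_t)n + tile_size - 1) / tile_size);

        printf("32-bit grid: %d, 64-bit grid: %d\n", grid_bad, grid_ok);
        return 0;
    }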
diff --git a/csrc/kernels.hip b/csrc/kernels.hip
index bef6cffa6..fdeab46f2 100644
--- a/csrc/kernels.hip
+++ b/csrc/kernels.hip
@@ -348,16 +348,17 @@ template <typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_
 //__launch_bounds__(TH, 4)
 __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
 {
-  const int n_full = gridDim.x * BLOCK_SIZE;
-  int valid_items = 0;
-  const int base_idx = (blockIdx.x * BLOCK_SIZE);
+  // Cast to int64_t to avoid overflow for large n
+  const int64_t n_full = (int64_t)gridDim.x * BLOCK_SIZE;
+  int valid_items = 0;
+  const int64_t base_idx = (int64_t)blockIdx.x * BLOCK_SIZE;
+
+  T vals[NUM_PER_TH];
+  float rand_vals[NUM_PER_TH];
+  unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH / 2 : NUM_PER_TH];
 
-  T vals[NUM_PER_TH];
-  float rand_vals[NUM_PER_TH];
-  unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH];
-  //float local_abs_max = -FLT_MAX;
-  float local_abs_max = 0.0f;
-  int local_rand_idx = 0;
+  float local_abs_max = 0.0f;
+  int local_rand_idx = 0;
 
   typedef hipcub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
   typedef hipcub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
@@ -375,9 +376,9 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
   for(int i = threadIdx.x; i < 256; i+=blockDim.x)
     smem_code[i] = code[i];
 
-  for (int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
-  {
-    valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
+
+  for (int64_t i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) {
+    valid_items = min(BLOCK_SIZE, static_cast<int>(n - i));
     local_abs_max = -FLT_MAX;
 
     __syncthreads();
@@ -465,7 +466,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax
 {
   if (DATA_TYPE > 0)
   {
-    valid_items_load = min(TILE_SIZE, (n + 1) / 2 - i);
+    // Cast n to int64_t to avoid overflow for large n
+    valid_items_load = min(TILE_SIZE, static_cast<int>((static_cast<int64_t>(n) + 1) / 2) - i);
     valid_items_store = min(TILE_SIZE * 2, n - i * 2);
   }
   else
diff --git a/csrc/ops.hip b/csrc/ops.hip
index b26d138e1..2fe68f9bd 100644
--- a/csrc/ops.hip
+++ b/csrc/ops.hip
@@ -34,7 +34,7 @@ void quantize(float *code, float *A, unsigned char *out, int n)
 {
   int num_blocks = n/1024;
   num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
-  hipLaunchKernelGGL(( kQuantize), dim3(num_blocks), dim3(1024), 0, 0, code, A, out, n);
+  hipLaunchKernelGGL(( kQuantize), dim3(num_blocks), dim3(1024), 0, 0, code, A, out, n);
   CUDA_CHECK_RETURN(hipPeekAtLastError());
 }
 
@@ -72,21 +72,21 @@ template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(floa
 
 template <typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n, hipStream_t stream)
 {
-  int num_blocks = n/blocksize;
-  num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
   int tile_size = (DATA_TYPE > 0) ? 1024 : 512;
+  // Upcast to int64 to avoid overflow for large n
+  int grid_blocks = ((int64_t)n + tile_size - 1) / tile_size;
+
   if(DATA_TYPE > 0)
-    hipLaunchKernelGGL(( kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE>), dim3((n+tile_size-1)/tile_size), dim3(64), 0, stream, code, A, absmax, out, blocksize/2, n);
+    hipLaunchKernelGGL(( kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE>), dim3(grid_blocks), dim3(64), 0, stream, code, A, absmax, out, blocksize / 2, n);
   else
-    hipLaunchKernelGGL(( kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE>), dim3((n+tile_size-1)/tile_size), dim3(64), 0, stream, code, A, absmax, out, blocksize, n);
+    hipLaunchKernelGGL(( kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE>), dim3(grid_blocks), dim3(64), 0, stream, code, A, absmax, out, blocksize, n);
   CUDA_CHECK_RETURN(hipPeekAtLastError());
 }
 
-
 template <typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
                 float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
                 const float beta1, const float beta2, const float beta3, const float alpha,
@@ -102,10 +102,10 @@ template <typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
       if(max_unorm > 0.0f)
       {
         CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float)));
-        hipLaunchKernelGGL(( kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8>), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
+        hipLaunchKernelGGL(( kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8>), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
         CUDA_CHECK_RETURN(hipPeekAtLastError());
       }
-      hipLaunchKernelGGL(( kOptimizer32bit2State<T, OPTIMIZER>), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+      hipLaunchKernelGGL(( kOptimizer32bit2State<T, OPTIMIZER>), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(hipPeekAtLastError());
       break;
     case MOMENTUM:
@@ -114,22 +114,22 @@ template <typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
       if(max_unorm > 0.0f)
       {
         CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float)));
-        hipLaunchKernelGGL(( kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8>), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
+        hipLaunchKernelGGL(( kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8>), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
         CUDA_CHECK_RETURN(hipPeekAtLastError());
       }
-      hipLaunchKernelGGL(( kOptimizer32bit1State<T, OPTIMIZER>), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+      hipLaunchKernelGGL(( kOptimizer32bit1State<T, OPTIMIZER>), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(hipPeekAtLastError());
       break;
     case LION:
      // in lion, the momentum update after the parameter update
-      hipLaunchKernelGGL(( kOptimizer32bit1State<T, OPTIMIZER>), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+      hipLaunchKernelGGL(( kOptimizer32bit1State<T, OPTIMIZER>), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(hipPeekAtLastError());
 
       if(max_unorm > 0.0f)
       {
         CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float)));
-        hipLaunchKernelGGL(( kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8>), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
+        hipLaunchKernelGGL(( kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8>), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
         CUDA_CHECK_RETURN(hipPeekAtLastError());
       }
       break;
@@ -156,9 +156,9 @@ template <typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
     case ADAM:
       CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float)));
       CUDA_CHECK_RETURN(hipMemset(new_max2, 0, 1*sizeof(float)));
-      hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER>), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
+      hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER>), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
       CUDA_CHECK_RETURN(hipPeekAtLastError());
-      hipLaunchKernelGGL(( kOptimizerStatic8bit2State<T, OPTIMIZER>), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
+      hipLaunchKernelGGL(( kOptimizerStatic8bit2State<T, OPTIMIZER>), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
                 quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n);
       CUDA_CHECK_RETURN(hipPeekAtLastError());
       break;
@@ -166,20 +166,20 @@ template <typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
     case RMSPROP:
     case ADAGRAD:
       CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float)));
-      hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER>), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
+      hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER>), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
       CUDA_CHECK_RETURN(hipPeekAtLastError());
-      hipLaunchKernelGGL(( kOptimizerStatic8bit1State<T, OPTIMIZER>), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
+      hipLaunchKernelGGL(( kOptimizerStatic8bit1State<T, OPTIMIZER>), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
                 quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
       CUDA_CHECK_RETURN(hipPeekAtLastError());
       break;
     case LION:
       // in lion, the momentum update happens after the parameter update
-      hipLaunchKernelGGL(( kOptimizerStatic8bit1State<T, OPTIMIZER>), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
+      hipLaunchKernelGGL(( kOptimizerStatic8bit1State<T, OPTIMIZER>), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
                 quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
       CUDA_CHECK_RETURN(hipPeekAtLastError());
 
       CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float)));
-      hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER>), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
+      hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER>), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
       CUDA_CHECK_RETURN(hipPeekAtLastError());
       break;
     default:
@@ -221,7 +221,7 @@ template <typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
     case ADEMAMIX:
       num_blocks = n/BLOCKSIZE_2STATE;
       num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1;
-      hipLaunchKernelGGL(( kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE>), dim3(num_blocks), dim3(BLOCKSIZE_2STATE/NUM_2STATE), 0, 0, p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr,
+      hipLaunchKernelGGL(( kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE>), dim3(num_blocks), dim3(BLOCKSIZE_2STATE/NUM_2STATE), 0, 0, p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr,
                 quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(hipPeekAtLastError());
       break;
@@ -231,7 +231,7 @@ template <typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
     case LION:
       num_blocks = n/BLOCKSIZE_1STATE;
       num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
-      hipLaunchKernelGGL(( kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE>), dim3(num_blocks), dim3(BLOCKSIZE_1STATE/NUM_1STATE), 0, 0, p, g, state1, beta1, beta2, eps, step, lr,
+      hipLaunchKernelGGL(( kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE>), dim3(num_blocks), dim3(BLOCKSIZE_1STATE/NUM_1STATE), 0, 0, p, g, state1, beta1, beta2, eps, step, lr,
                 quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(hipPeekAtLastError());
       break;
@@ -245,7 +245,7 @@ template <typename T> void percentileClipping(T * g, float *gnorm_vec, int step, int
   int num_blocks = n/2048;
   num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1;
   CUDA_CHECK_RETURN(hipMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float)));
-  hipLaunchKernelGGL(( kPercentileClipping<T, 2048, 4>), dim3(num_blocks), dim3(512), 0, 0, g, gnorm_vec, step, n);
+  hipLaunchKernelGGL(( kPercentileClipping<T, 2048, 4>), dim3(num_blocks), dim3(512), 0, 0, g, gnorm_vec, step, n);
   CUDA_CHECK_RETURN(hipPeekAtLastError());
 }
 
@@ -669,7 +669,7 @@ void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_va
 
 template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
 {
-  hipLaunchKernelGGL(( kspmm_coo_very_sparse_naive<T, 8, BITS>), dim3(nnz_rows), dim3(256), 0, 0, max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB);
+  hipLaunchKernelGGL(( kspmm_coo_very_sparse_naive<T, 8, BITS>), dim3(nnz_rows), dim3(256), 0, 0, max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB);
   CUDA_CHECK_RETURN(hipPeekAtLastError());
 }
 
@@ -679,9 +679,9 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
   int num_blocks = (m+31)/32;
 
   if(bits == 32)
-    hipLaunchKernelGGL(( gemm_device<T, 32, 32>), dim3(num_blocks), dim3(32), 0, 0, m, n, k, A, B, out, lda, ldb, ldc);
+    hipLaunchKernelGGL(( gemm_device<T, 32, 32>), dim3(num_blocks), dim3(32), 0, 0, m, n, k, A, B, out, lda, ldb, ldc);
   if(bits == 16)
-    hipLaunchKernelGGL(( gemm_device<T, 16, 160>), dim3(num_blocks), dim3(160), 0, 0, m, n, k, A, B, out, lda, ldb, ldc);
+    hipLaunchKernelGGL(( gemm_device<T, 16, 160>), dim3(num_blocks), dim3(160), 0, 0, m, n, k, A, B, out, lda, ldb, ldc);
 }
 
 template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
@@ -689,7 +689,7 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
 
   int num_blocks = (m+31)/32;
 
-  hipLaunchKernelGGL(( kgemm_4bit_inference<T, 96>), dim3(num_blocks), dim3(96), 0, 0, m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+  hipLaunchKernelGGL(( kgemm_4bit_inference<T, 96>), dim3(num_blocks), dim3(96), 0, 0, m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
 }
 
 template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream)
@@ -712,7 +712,7 @@ template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
   int blocks = n/threads;
   blocks = n % threads == 0 ? blocks : blocks + 1;
   blocks = blocks > 65535 ? 65535 : blocks;
-  hipLaunchKernelGGL(( kfunc<T, FUNC>), dim3(blocks), dim3(512), 0, 0, A, B, value, n);
+  hipLaunchKernelGGL(( kfunc<T, FUNC>), dim3(blocks), dim3(512), 0, 0, A, B, value, n);
   CUDA_CHECK_RETURN(hipPeekAtLastError());
 }
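
A companion sketch for the 4-bit load count in kDequantizeBlockwise: the
packed byte count is (n + 1) / 2, and with n at INT_MAX the 32-bit sum
n + 1 wraps before the division, which is what the static_cast<int64_t>
in the kernels.hip hunk avoids (again a standalone illustration, not code
taken from the kernels):

    #include <climits>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int n = INT_MAX;  // worst case: 2^31 - 1 packed 4-bit values

        // 32-bit: n + 1 wraps to INT_MIN (undefined behavior in theory,
        // wraparound in practice), so the byte count goes negative.
        int bytes_bad = (n + 1) / 2;

        // 64-bit intermediate, mirroring the patched kernel: 2^31 / 2 = 2^30.
        int bytes_ok = (int)((static_cast<int64_t>(n) + 1) / 2);

        printf("32-bit: %d, 64-bit: %d\n", bytes_bad, bytes_ok);
        return 0;
    }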