diff --git a/csrc/transformer/softmax_kernels.cu b/csrc/transformer/softmax_kernels.cu index be776b0c074d..4c0d987e617a 100644 --- a/csrc/transformer/softmax_kernels.cu +++ b/csrc/transformer/softmax_kernels.cu @@ -4,6 +4,20 @@ namespace cg = cooperative_groups; +dim3 get_attn_softmax_grid(int batch_size, int heads, int sequence_length, int threads) +{ + int seq_length4 = sequence_length / 4; + int block_compute_size = + (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1); + // Note that the Y and Z dimensions are limited to 65535, while X is basically unlimited: + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications + // The batch size is typically relatively small, while the sequence length could potentially be + // arbitrarily large. We therefore place the batch size second to avoid hitting the Y limit. + unsigned x = heads * sequence_length / block_compute_size; + unsigned y = batch_size; + return {x, y}; +} + // Fused attention + softmax template __global__ void attn_softmax(float* vals, @@ -22,12 +36,12 @@ __global__ void attn_softmax(float* vals, cg::thread_block b = cg::this_thread_block(); cg::thread_block_tile g = cg::tiled_partition(b); - int batch = blockIdx.x; - int row = blockIdx.y; + int batch = blockIdx.y; + int row = blockIdx.x; int max_threads_in_sequence = std::max(seq_length, tbSeq); int seq_lane = threadIdx.x % max_threads_in_sequence; - int data_offset = batch * (gridDim.y * block_width) + row * block_width + + int data_offset = batch * (gridDim.x * block_width) + row * block_width + (threadIdx.x / max_threads_in_sequence) * seq_length; int mask_offset = batch * seq_length; @@ -153,12 +167,12 @@ __global__ void attn_softmax(__half* vals, cg::thread_block b = cg::this_thread_block(); cg::thread_block_tile g = cg::tiled_partition(b); - int batch = blockIdx.x; - int row = blockIdx.y; + int batch = blockIdx.y; + int row = blockIdx.x; int max_threads_in_sequence = std::max(seq_length, tbSeq); int seq_lane = threadIdx.x % max_threads_in_sequence; - int data_offset = batch * (gridDim.y * block_width) + row * block_width + + int data_offset = batch * (gridDim.x * block_width) + row * block_width + (threadIdx.x / max_threads_in_sequence) * seq_length; int mask_offset = batch * seq_length; @@ -300,9 +314,7 @@ void launch_attn_softmax(float* vals, const int threads = 128; int seq_length4 = sequence_length / 4; - int block_compute_size = - (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1); - dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size); + dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; @@ -333,10 +345,7 @@ void launch_attn_softmax(float* vals, <<>>(vals, attn_mask, heads, seq_length4, iterations); else { const int threads = 256; - block_compute_size = - (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) - : 1); - dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size); + dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; @@ -370,9 +379,7 @@ void launch_attn_softmax<__half>(__half* vals, const int threads = 128; int seq_length4 = sequence_length / 4; - int block_compute_size = - (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1); - dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size); + dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; @@ -404,10 +411,7 @@ void launch_attn_softmax<__half>(__half* vals, <<>>(vals, attn_mask, heads, seq_length4, iterations); else { const int threads = 256; - block_compute_size = - (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) - : 1); - dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size); + dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; diff --git a/tests/unit/test_cuda_forward.py b/tests/unit/test_cuda_forward.py index 8559f6bb5319..5231d2cdecf5 100755 --- a/tests/unit/test_cuda_forward.py +++ b/tests/unit/test_cuda_forward.py @@ -232,6 +232,7 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None): (8,128,128,2,3,True,True), (8,4096,128,64,3,True,True), (8,8192,128,64,3,False,True), + (1,256,2048,32,3,True,True), ]) # yapf: disable def test_forward(batch_size, hidden_size,