Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 24 additions & 20 deletions csrc/transformer/softmax_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,20 @@

namespace cg = cooperative_groups;

// Computes the launch grid for the fused attention-softmax kernels.
//
// The grid packs several sequence rows into one block when the (quarter)
// sequence length is smaller than the thread count: block_compute_size is the
// largest power of two <= threads / (sequence_length / 4), clamped to 1.
//
// Note that the Y and Z dimensions are limited to 65535, while X is basically unlimited:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications
// The batch size is typically relatively small, while the sequence length could potentially be
// arbitrarily large. We therefore place the batch size second to avoid hitting the Y limit.
dim3 get_attn_softmax_grid(int batch_size, int heads, int sequence_length, int threads)
{
    int seq_length4 = sequence_length / 4;
    // Largest power of two <= threads / seq_length4, computed with integer
    // shifts instead of pow/log2/floor double-precision round-tripping.
    // Guard seq_length4 > 0: sequence_length < 4 would otherwise divide by zero.
    int block_compute_size = 1;
    if (seq_length4 > 0 && seq_length4 < threads) {
        int ratio = threads / seq_length4;
        while ((block_compute_size << 1) <= ratio) block_compute_size <<= 1;
    }
    unsigned x = heads * sequence_length / block_compute_size;
    unsigned y = batch_size;
    return {x, y};
}

// Fused attention + softmax
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(float* vals,
Expand All @@ -22,12 +36,12 @@ __global__ void attn_softmax(float* vals,
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);

int batch = blockIdx.x;
int row = blockIdx.y;
int batch = blockIdx.y;
int row = blockIdx.x;
int max_threads_in_sequence = std::max(seq_length, tbSeq);
int seq_lane = threadIdx.x % max_threads_in_sequence;

int data_offset = batch * (gridDim.y * block_width) + row * block_width +
int data_offset = batch * (gridDim.x * block_width) + row * block_width +
(threadIdx.x / max_threads_in_sequence) * seq_length;
int mask_offset = batch * seq_length;

Expand Down Expand Up @@ -153,12 +167,12 @@ __global__ void attn_softmax(__half* vals,
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);

int batch = blockIdx.x;
int row = blockIdx.y;
int batch = blockIdx.y;
int row = blockIdx.x;
int max_threads_in_sequence = std::max(seq_length, tbSeq);
int seq_lane = threadIdx.x % max_threads_in_sequence;

int data_offset = batch * (gridDim.y * block_width) + row * block_width +
int data_offset = batch * (gridDim.x * block_width) + row * block_width +
(threadIdx.x / max_threads_in_sequence) * seq_length;
int mask_offset = batch * seq_length;

Expand Down Expand Up @@ -300,9 +314,7 @@ void launch_attn_softmax<float>(float* vals,
const int threads = 128;
int seq_length4 = sequence_length / 4;

int block_compute_size =
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);

int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;

Expand Down Expand Up @@ -333,10 +345,7 @@ void launch_attn_softmax<float>(float* vals,
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else {
const int threads = 256;
block_compute_size =
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4))))
: 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);

int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;

Expand Down Expand Up @@ -370,9 +379,7 @@ void launch_attn_softmax<__half>(__half* vals,
const int threads = 128;
int seq_length4 = sequence_length / 4;

int block_compute_size =
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);

int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;

Expand Down Expand Up @@ -404,10 +411,7 @@ void launch_attn_softmax<__half>(__half* vals,
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else {
const int threads = 256;
block_compute_size =
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4))))
: 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);

int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;

Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_cuda_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):
(8,128,128,2,3,True,True),
(8,4096,128,64,3,True,True),
(8,8192,128,64,3,False,True),
(1,256,2048,32,3,True,True),
]) # yapf: disable
def test_forward(batch_size,
hidden_size,
Expand Down