From 8b3604869be3be2726312605c5d4a5917e0b3ef0 Mon Sep 17 00:00:00 2001 From: nomadlx Date: Wed, 4 Aug 2021 11:54:29 +0800 Subject: [PATCH] fix unrestricted threads and dynamic shared memory bug --- .../csrc/kernels/embedding_kernels.cu | 86 +++++++++++-------- 1 file changed, 49 insertions(+), 37 deletions(-) diff --git a/lightseq/training/csrc/kernels/embedding_kernels.cu b/lightseq/training/csrc/kernels/embedding_kernels.cu index 604d4b378..a17ca531d 100644 --- a/lightseq/training/csrc/kernels/embedding_kernels.cu +++ b/lightseq/training/csrc/kernels/embedding_kernels.cu @@ -10,42 +10,56 @@ get tokens position in sequences that the padding tokens are ignored. @thread gridDim.x = batch_size gridDim.y = 1 -blockDim.x = block_x_tokens -blockDim.y = block_y_tokens +blockDim.x = min(seq_len, MAX_THREADS) +blockDim.y = 1 @param -output: [batch_size, seq_len] +output: [batch_size, seq_len, 2] input: [batch_size, seq_len] +batch_size: the size of the current batch seq_len: the sequence length of the current batch padding_idx: padding index of the sentences (default: 2) */ __global__ void get_tokens_position( - int *output, const int *input, int seq_len, int padding_idx) { + int *output, const int *input, int batch_size, int seq_len, int padding_idx) { int batch_id = blockIdx.x; - int seq_id = threadIdx.x * blockDim.y + threadIdx.y; - if (seq_id >= seq_len) return; + int start_seq_id = threadIdx.x; + int threads = blockDim.x; - int target_pos = batch_id * seq_len + seq_id; + int batch_offset = batch_id * seq_len; + int temp_offset[2]; + temp_offset[0] = 0; + temp_offset[1] = batch_size * seq_len; - extern __shared__ int temp[]; + int *temp = output; int pout = 0, pin = 1; - int pout_idx = batch_id * seq_len * 2 + pout * seq_len + seq_id; - int pin_idx = 0; - temp[pout_idx] = (seq_id > 0 && input[target_pos - 1] != padding_idx) ? 1 : 0; + int pout_idx, pin_idx, target_pos; + + for (int seq_id = start_seq_id; seq_id < seq_len; seq_id += threads) { + target_pos = batch_offset + seq_id; + pout_idx = temp_offset[pout] + batch_offset + seq_id; + temp[pout_idx] = (seq_id > 0 && input[target_pos - 1] != padding_idx) ? 1 : 0; + } __syncthreads(); for (int stride = 1; stride < seq_len; stride *= 2) { pout = 1 - pout; pin = 1 - pout; - pout_idx = batch_id * seq_len * 2 + pout * seq_len + seq_id; - pin_idx = batch_id * seq_len * 2 + pin * seq_len + seq_id; - if (seq_id >= stride) - temp[pout_idx] = temp[pin_idx] + temp[pin_idx - stride]; - else - temp[pout_idx] = temp[pin_idx]; + for (int seq_id = start_seq_id; seq_id < seq_len; seq_id += threads) { + pout_idx = temp_offset[pout] + batch_offset + seq_id; + pin_idx = temp_offset[pin] + batch_offset + seq_id; + if (seq_id >= stride) + temp[pout_idx] = temp[pin_idx] + temp[pin_idx - stride]; + else + temp[pout_idx] = temp[pin_idx]; + } __syncthreads(); } - - output[target_pos] = temp[pout_idx]; + + for (int seq_id = start_seq_id; seq_id < seq_len; seq_id += threads) { + target_pos = batch_offset + seq_id; + pout_idx = temp_offset[pout] + batch_offset + seq_id; + output[target_pos] = temp[pout_idx]; + } } /** @@ -227,16 +241,14 @@ void launch_lookup_scale_pos_dropout( int seq_len, int embedding_dim, int padding_idx, float dropout_ratio, int step, cudaStream_t &stream) { int *tokens_position; - cudaMalloc(&tokens_position, batch_size * seq_len * sizeof(int)); - int block_x_tokens = min(seq_len, MAX_THREADS); - int block_y_tokens = (seq_len + block_x_tokens - 1) / block_x_tokens; - int dynamic_temp_size = batch_size * seq_len * 2 * sizeof(int); + cudaMalloc(&tokens_position, batch_size * seq_len * 2 * sizeof(int)); + int p_threads = min(seq_len, MAX_THREADS); dim3 p_grid_dim(batch_size, 1); - dim3 p_block_dim(block_x_tokens, block_y_tokens); - - // get the position index of the tokens alone, because synchronization is required at the sequence level - get_tokens_position<<>>( - tokens_position, input, seq_len, padding_idx); + dim3 p_block_dim(p_threads, 1); + // get the position index of the tokens alone, + // because synchronization is required at the sequence level + get_tokens_position<<>>( + tokens_position, input, batch_size, seq_len, padding_idx); float emb_scale = sqrt(embedding_dim); embedding_dim >>= 2; @@ -253,6 +265,7 @@ void launch_lookup_scale_pos_dropout( output, input, tokens_position, embeddings, pos_embeddings, dropout_mask, seq_len, embedding_dim, padding_idx, dropout_ratio, emb_scale, step, seed); + cudaFree(tokens_position); } @@ -263,16 +276,14 @@ void launch_lookup_scale_pos_dropout<__half>( int seq_len, int embedding_dim, int padding_idx, float dropout_ratio, int step, cudaStream_t &stream) { int *tokens_position; - cudaMalloc(&tokens_position, batch_size * seq_len * sizeof(int)); - int block_x_tokens = min(seq_len, MAX_THREADS); - int block_y_tokens = (seq_len + block_x_tokens - 1) / block_x_tokens; - int dynamic_temp_size = batch_size * seq_len * 2 * sizeof(int); + cudaMalloc(&tokens_position, batch_size * seq_len * 2 * sizeof(int)); + int p_threads = min(seq_len, MAX_THREADS); dim3 p_grid_dim(batch_size, 1); - dim3 p_block_dim(block_x_tokens, block_y_tokens); - - // get the position index of the tokens alone, because synchronization is required at the sequence level - get_tokens_position<<>>( - tokens_position, input, seq_len, padding_idx); + dim3 p_block_dim(p_threads, 1); + // get the position index of the tokens alone, + // because synchronization is required at the sequence level + get_tokens_position<<>>( + tokens_position, input, batch_size, seq_len, padding_idx); float emb_scale = sqrt(embedding_dim); embedding_dim >>= 3; @@ -289,6 +300,7 @@ void launch_lookup_scale_pos_dropout<__half>( output, input, tokens_position, embeddings, pos_embeddings, dropout_mask, seq_len, embedding_dim, padding_idx, dropout_ratio, emb_scale, step, seed); + cudaFree(tokens_position); }