fix unrestricted threads and dynamic shared memory bug
nomadlx committed Aug 4, 2021
1 parent c94c0fb commit 8b36048
Showing 1 changed file with 49 additions and 37 deletions.
86 changes: 49 additions & 37 deletions lightseq/training/csrc/kernels/embedding_kernels.cu
@@ -10,42 +10,56 @@ get tokens position in sequences that the padding tokens are ignored.
@thread
gridDim.x = batch_size
gridDim.y = 1
-blockDim.x = block_x_tokens
-blockDim.y = block_y_tokens
+blockDim.x = min(seq_len, MAX_THREADS)
+blockDim.y = 1

@param
-output: [batch_size, seq_len]
+output: [batch_size, seq_len, 2]
input: [batch_size, seq_len]
batch_size: the size of the current batch
seq_len: the sequence length of the current batch
padding_idx: padding index of the sentences (default: 2)
*/
__global__ void get_tokens_position(
-    int *output, const int *input, int seq_len, int padding_idx) {
+    int *output, const int *input, int batch_size, int seq_len, int padding_idx) {
  int batch_id = blockIdx.x;
-  int seq_id = threadIdx.x * blockDim.y + threadIdx.y;
-  if (seq_id >= seq_len) return;
+  int start_seq_id = threadIdx.x;
+  int threads = blockDim.x;

-  int target_pos = batch_id * seq_len + seq_id;
+  int batch_offset = batch_id * seq_len;
+  int temp_offset[2];
+  temp_offset[0] = 0;
+  temp_offset[1] = batch_size * seq_len;

-  extern __shared__ int temp[];
+  int *temp = output;
  int pout = 0, pin = 1;
-  int pout_idx = batch_id * seq_len * 2 + pout * seq_len + seq_id;
-  int pin_idx = 0;
-  temp[pout_idx] = (seq_id > 0 && input[target_pos - 1] != padding_idx) ? 1 : 0;
+  int pout_idx, pin_idx, target_pos;
+
+  for (int seq_id = start_seq_id; seq_id < seq_len; seq_id += threads) {
+    target_pos = batch_offset + seq_id;
+    pout_idx = temp_offset[pout] + batch_offset + seq_id;
+    temp[pout_idx] = (seq_id > 0 && input[target_pos - 1] != padding_idx) ? 1 : 0;
+  }
  __syncthreads();
  for (int stride = 1; stride < seq_len; stride *= 2) {
    pout = 1 - pout;
    pin = 1 - pout;
-    pout_idx = batch_id * seq_len * 2 + pout * seq_len + seq_id;
-    pin_idx = batch_id * seq_len * 2 + pin * seq_len + seq_id;
-    if (seq_id >= stride)
-      temp[pout_idx] = temp[pin_idx] + temp[pin_idx - stride];
-    else
-      temp[pout_idx] = temp[pin_idx];
+    for (int seq_id = start_seq_id; seq_id < seq_len; seq_id += threads) {
+      pout_idx = temp_offset[pout] + batch_offset + seq_id;
+      pin_idx = temp_offset[pin] + batch_offset + seq_id;
+      if (seq_id >= stride)
+        temp[pout_idx] = temp[pin_idx] + temp[pin_idx - stride];
+      else
+        temp[pout_idx] = temp[pin_idx];
+    }
    __syncthreads();
  }

-  output[target_pos] = temp[pout_idx];
+  for (int seq_id = start_seq_id; seq_id < seq_len; seq_id += threads) {
+    target_pos = batch_offset + seq_id;
+    pout_idx = temp_offset[pout] + batch_offset + seq_id;
+    output[target_pos] = temp[pout_idx];
+  }
}

/**
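For reference, the reworked get_tokens_position kernel is a per-sequence inclusive prefix sum (a double-buffered Hillis-Steele scan) over the indicator "the previous token is not padding", computed by a capped number of threads that stride over the sequence. The first batch_size * seq_len ints of output receive the result; the second half is scan scratch. A minimal host-side sketch of what the result holds, for illustration only (the helper name and std::vector interface are not from this commit):

#include <vector>

// For every token, the number of non-padding tokens strictly before it in
// its own sequence, i.e. its position once padding tokens are skipped.
std::vector<int> tokens_position_reference(const std::vector<int> &input,
                                           int batch_size, int seq_len,
                                           int padding_idx) {
  std::vector<int> out(batch_size * seq_len, 0);
  for (int b = 0; b < batch_size; ++b) {
    int count = 0;  // non-padding tokens seen so far in this sequence
    for (int s = 0; s < seq_len; ++s) {
      out[b * seq_len + s] = count;
      if (input[b * seq_len + s] != padding_idx) ++count;
    }
  }
  return out;
}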
@@ -227,16 +241,14 @@ void launch_lookup_scale_pos_dropout<float>(
    int seq_len, int embedding_dim, int padding_idx, float dropout_ratio,
    int step, cudaStream_t &stream) {
  int *tokens_position;
-  cudaMalloc(&tokens_position, batch_size * seq_len * sizeof(int));
-  int block_x_tokens = min(seq_len, MAX_THREADS);
-  int block_y_tokens = (seq_len + block_x_tokens - 1) / block_x_tokens;
-  int dynamic_temp_size = batch_size * seq_len * 2 * sizeof(int);
+  cudaMalloc(&tokens_position, batch_size * seq_len * 2 * sizeof(int));
+  int p_threads = min(seq_len, MAX_THREADS);
  dim3 p_grid_dim(batch_size, 1);
-  dim3 p_block_dim(block_x_tokens, block_y_tokens);
-
-  // get the position index of the tokens alone, because synchronization is required at the sequence level
-  get_tokens_position<<<p_grid_dim, p_block_dim, dynamic_temp_size, stream>>>(
-      tokens_position, input, seq_len, padding_idx);
+  dim3 p_block_dim(p_threads, 1);
+  // get the position index of the tokens alone,
+  // because synchronization is required at the sequence level
+  get_tokens_position<<<p_grid_dim, p_block_dim, 0, stream>>>(
+      tokens_position, input, batch_size, seq_len, padding_idx);

  float emb_scale = sqrt(embedding_dim);
  embedding_dim >>= 2;
@@ -253,6 +265,7 @@ void launch_lookup_scale_pos_dropout<float>(
      output, input, tokens_position, embeddings, pos_embeddings, dropout_mask,
      seq_len, embedding_dim, padding_idx, dropout_ratio, emb_scale, step,
      seed);
+
  cudaFree(tokens_position);
}
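The removed launch configuration above is what the commit title refers to. A small standalone sketch of the two failure modes it could hit, using example shapes and assuming MAX_THREADS is 1024 (an assumption about the unchanged header, not shown in this diff):

#include <cstdio>

int main() {
  const int MAX_THREADS = 1024;          // assumed value of the LightSeq constant
  int batch_size = 64, seq_len = 2048;   // example shapes only
  // Old grid: blockDim covered the whole sequence, so the thread count per
  // block grew with seq_len and breaks the 1024-threads-per-block limit.
  int block_x = seq_len < MAX_THREADS ? seq_len : MAX_THREADS;
  int block_y = (seq_len + block_x - 1) / block_x;
  printf("old threads per block: %d (hardware limit: 1024)\n",
         block_x * block_y);
  // Old launch: dynamic shared memory was sized for the whole batch, far
  // beyond the 48 KB default per-block limit; the new launch passes 0 and
  // uses the (doubled) output buffer in global memory as scan scratch.
  printf("old dynamic shared memory: %zu bytes (default limit: 49152)\n",
         (size_t)batch_size * seq_len * 2 * sizeof(int));
  return 0;
}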

@@ -263,16 +276,14 @@ void launch_lookup_scale_pos_dropout<__half>(
    int seq_len, int embedding_dim, int padding_idx, float dropout_ratio,
    int step, cudaStream_t &stream) {
  int *tokens_position;
-  cudaMalloc(&tokens_position, batch_size * seq_len * sizeof(int));
-  int block_x_tokens = min(seq_len, MAX_THREADS);
-  int block_y_tokens = (seq_len + block_x_tokens - 1) / block_x_tokens;
-  int dynamic_temp_size = batch_size * seq_len * 2 * sizeof(int);
+  cudaMalloc(&tokens_position, batch_size * seq_len * 2 * sizeof(int));
+  int p_threads = min(seq_len, MAX_THREADS);
  dim3 p_grid_dim(batch_size, 1);
-  dim3 p_block_dim(block_x_tokens, block_y_tokens);
-
-  // get the position index of the tokens alone, because synchronization is required at the sequence level
-  get_tokens_position<<<p_grid_dim, p_block_dim, dynamic_temp_size, stream>>>(
-      tokens_position, input, seq_len, padding_idx);
+  dim3 p_block_dim(p_threads, 1);
+  // get the position index of the tokens alone,
+  // because synchronization is required at the sequence level
+  get_tokens_position<<<p_grid_dim, p_block_dim, 0, stream>>>(
+      tokens_position, input, batch_size, seq_len, padding_idx);

  float emb_scale = sqrt(embedding_dim);
  embedding_dim >>= 3;
@@ -289,6 +300,7 @@ void launch_lookup_scale_pos_dropout<__half>(
      output, input, tokens_position, embeddings, pos_embeddings, dropout_mask,
      seq_len, embedding_dim, padding_idx, dropout_ratio, emb_scale, step,
      seed);
+
  cudaFree(tokens_position);
}
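The only remaining difference between the float and __half launchers is the shift applied to embedding_dim (>>= 2 vs. >>= 3). Presumably the unchanged lookup kernel reads embeddings through float4-sized vector loads; that code is not shown in this diff, so the snippet below is only a hedged illustration of the arithmetic:

#include <cuda_runtime.h>
#include <cuda_fp16.h>

// One float4 load moves 16 bytes: 4 floats or 8 __half values, which is
// consistent with dividing embedding_dim by 4 (>>= 2) or by 8 (>>= 3).
static_assert(sizeof(float4) / sizeof(float) == 4, "float path");
static_assert(sizeof(float4) / sizeof(__half) == 8, "__half path");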

