Skip to content

Commit

Permalink
fix pos embedding index bug
Browse files Browse the repository at this point in the history
  • Loading branch information
nomadlx committed Aug 2, 2021
1 parent 79249ad commit 4224b87
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 21 deletions.
2 changes: 1 addition & 1 deletion examples/training/fairseq/fs_modules/ls_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def build_embedding(cls, args, dictionary, embed_dim, max_positions, **kwargs):
vocab_size=len(dictionary),
embedding_dim=embed_dim,
max_batch_tokens=args.max_tokens,
max_seq_len=MAX_SEQ_LENGTH, # FIXME later
max_seq_len=max_positions,
padding_idx=dictionary.pad(),
dropout=args.dropout,
fp16=args.fp16,
Expand Down
119 changes: 99 additions & 20 deletions lightseq/training/csrc/kernels/embedding_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,51 @@

#include "kernels.h"

/**
@brief: get_tokens_position
Compute each token's position within its sequence, skipping padding tokens,
using a per-block Hillis-Steele parallel prefix sum over
"previous token is not padding" indicators.

@thread
gridDim.x = batch_size
gridDim.y = 1
blockDim.x = block_x_tokens
blockDim.y = block_y_tokens
blockDim.x * blockDim.y >= seq_len; threads past seq_len are masked but must
still reach every __syncthreads() barrier.

@param
output: [batch_size, seq_len], position of each token among non-pad tokens
input: [batch_size, seq_len], token ids
seq_len: the sequence length of the current batch
padding_idx: padding token id (fairseq dictionaries use 2 by default —
TODO confirm for other vocabularies)

Dynamic shared memory required: 2 * seq_len * sizeof(int) per block
(double buffer for the scan). NOTE(review): the launcher currently passes
batch_size * seq_len * 2 * sizeof(int); only the per-block 2 * seq_len slice
is needed, so the larger allocation is safe but wasteful — verify and shrink.
*/
__global__ void get_tokens_position(
    int *output, const int *input, int seq_len, int padding_idx) {
  const int batch_id = blockIdx.x;
  const int seq_id = threadIdx.x * blockDim.y + threadIdx.y;
  // Do NOT early-return out-of-range threads: every thread in the block must
  // reach the __syncthreads() calls below, otherwise the barrier is divergent
  // (undefined behavior). Mask inactive threads with `active` instead.
  const bool active = seq_id < seq_len;

  const int target_pos = batch_id * seq_len + seq_id;

  // Double-buffered scan storage. Shared memory is per-block, so it is
  // indexed by seq_id only — no batch offset is required.
  extern __shared__ int temp[];
  int pout = 0, pin = 1;
  if (active) {
    // Indicator: 1 iff a preceding token exists and is not padding, so the
    // scan yields the count of non-pad tokens strictly before seq_id.
    temp[pout * seq_len + seq_id] =
        (seq_id > 0 && input[target_pos - 1] != padding_idx) ? 1 : 0;
  }
  __syncthreads();

  // Hillis-Steele inclusive scan: log2(seq_len) rounds, ping-ponging between
  // the two seq_len-sized halves of temp.
  for (int stride = 1; stride < seq_len; stride *= 2) {
    pout = 1 - pout;
    pin = 1 - pout;
    if (active) {
      const int out_idx = pout * seq_len + seq_id;
      const int in_idx = pin * seq_len + seq_id;
      temp[out_idx] = (seq_id >= stride)
                          ? temp[in_idx] + temp[in_idx - stride]
                          : temp[in_idx];
    }
    __syncthreads();
  }

  if (active) output[target_pos] = temp[pout * seq_len + seq_id];
}

/**
@brief: lookup_scale_pos_dropout
forward of embedding layer in fairseq, including
Expand All @@ -15,8 +60,9 @@ blockDim.x = tokens_per_block
blockDim.y = min(embedding_dim, MAX_THREADS)
@param
input: [batch_size, seq_len]
output: [batch_size, seq_len, embedding_dim]
input: [batch_size, seq_len]
tokens_position: [batch_size, seq_len]
embeddings: [vocab_size, embedding_dim]
pos_embeddings: [max_seq_len, embedding_dim]
dropout_mask: [batch_size, seq_len, embedding_dim]
Expand All @@ -29,16 +75,17 @@ training and valid)
*/
template <typename T>
__global__ void lookup_scale_pos_dropout(
T *output, const int *input, const T *embeddings, const T *pos_embeddings,
uint8_t *dropout_mask, int seq_len, int embedding_dim, int padding_idx,
float dropout_ratio, float emb_scale, int step, int seed);
T *output, const int *input, const int *tokens_position,
const T *embeddings, const T *pos_embeddings, uint8_t *dropout_mask,
int seq_len, int embedding_dim, int padding_idx, float dropout_ratio,
float emb_scale, int step, int seed);

template <>
__global__ void lookup_scale_pos_dropout<float>(
float *output, const int *input, const float *embeddings,
const float *pos_embeddings, uint8_t *dropout_mask, int seq_len,
int embedding_dim, int padding_idx, float dropout_ratio, float emb_scale,
int step, int seed) {
float *output, const int *input, const int *tokens_position,
const float *embeddings, const float *pos_embeddings, uint8_t *dropout_mask,
int seq_len, int embedding_dim, int padding_idx, float dropout_ratio,
float emb_scale, int step, int seed) {
int batch_id = blockIdx.x;
int seq_id = blockIdx.y * blockDim.x + threadIdx.x;
if (seq_id >= seq_len) return;
Expand All @@ -48,6 +95,8 @@ __global__ void lookup_scale_pos_dropout<float>(
int end = (target_pos + 1) * embedding_dim;
int tid = input[target_pos];

int token_pos_id = tokens_position[target_pos];

float4 *output4 = reinterpret_cast<float4 *>(output);
const float4 *embeddings4 = reinterpret_cast<const float4 *>(embeddings);
const float4 *pos_embeddings4 =
Expand Down Expand Up @@ -81,7 +130,8 @@ __global__ void lookup_scale_pos_dropout<float>(
int offset = i - target_pos * embedding_dim;
float4 e4 = embeddings4[tid * embedding_dim + offset];
// step is non-zero only in inference
float4 pe4 = pos_embeddings4[(seq_id + step) * embedding_dim + offset];
// position numbers begin at padding_idx + 1, same as the fairseq implementation
float4 pe4 = pos_embeddings4[(token_pos_id + step + padding_idx + 1) * embedding_dim + offset];
float4 res4;
res4.x = (emb_scale * e4.x + pe4.x) * scale * m[0];
res4.y = (emb_scale * e4.y + pe4.y) * scale * m[1];
Expand All @@ -93,10 +143,10 @@ __global__ void lookup_scale_pos_dropout<float>(

template <>
__global__ void lookup_scale_pos_dropout<__half>(
__half *output, const int *input, const __half *embeddings,
const __half *pos_embeddings, uint8_t *dropout_mask, int seq_len,
int embedding_dim, int padding_idx, float dropout_ratio, float emb_scale,
int step, int seed) {
__half *output, const int *input, const int *tokens_position,
const __half *embeddings, const __half *pos_embeddings,
uint8_t *dropout_mask, int seq_len, int embedding_dim, int padding_idx,
float dropout_ratio, float emb_scale, int step, int seed) {
int batch_id = blockIdx.x;
int seq_id = blockIdx.y * blockDim.x + threadIdx.x;
if (seq_id >= seq_len) return;
Expand All @@ -106,6 +156,8 @@ __global__ void lookup_scale_pos_dropout<__half>(
int end = (target_pos + 1) * embedding_dim;
int tid = input[target_pos];

int token_pos_id = tokens_position[target_pos];

float4 *output4 = reinterpret_cast<float4 *>(output);
const float4 *embeddings4 = reinterpret_cast<const float4 *>(embeddings);
const float4 *pos_embeddings4 =
Expand Down Expand Up @@ -144,7 +196,8 @@ __global__ void lookup_scale_pos_dropout<__half>(
int offset = i - target_pos * embedding_dim;
float4 e4 = embeddings4[tid * embedding_dim + offset];
// step is non-zero only in inference
float4 pe4 = pos_embeddings4[(seq_id + step) * embedding_dim + offset];
// position numbers begin at padding_idx + 1, same as the fairseq implementation
float4 pe4 = pos_embeddings4[(token_pos_id + step + padding_idx + 1) * embedding_dim + offset];
float4 res4;

__half2 *e_h2 = reinterpret_cast<__half2 *>(&e4);
Expand Down Expand Up @@ -175,6 +228,18 @@ void launch_lookup_scale_pos_dropout<float>(
const float *pos_embeddings, uint8_t *dropout_mask, int batch_size,
int seq_len, int embedding_dim, int padding_idx, float dropout_ratio,
int step, cudaStream_t &stream) {
int *tokens_position;
cudaMalloc(&tokens_position, batch_size * seq_len * sizeof(int));
int block_x_tokens = min(seq_len, MAX_THREADS);
int block_y_tokens = (seq_len + block_x_tokens - 1) / block_x_tokens;
int dynamic_temp_size = batch_size * seq_len * 2 * sizeof(int);
dim3 p_grid_dim(batch_size, 1);
dim3 p_block_dim(block_x_tokens, block_y_tokens);

// compute the token position indices in a separate kernel, because
// sequence-level synchronization is required
get_tokens_position<<<p_grid_dim, p_block_dim, dynamic_temp_size, stream>>>(
tokens_position, input, seq_len, padding_idx);

float emb_scale = sqrt(embedding_dim);
embedding_dim >>= 2;

Expand All @@ -186,10 +251,11 @@ void launch_lookup_scale_pos_dropout<float>(
int seed = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();

lookup_scale_pos_dropout<float><<<grid_dim, block_dim, 0, stream>>>(
output, input, embeddings, pos_embeddings, dropout_mask, seq_len,
embedding_dim, padding_idx, dropout_ratio, emb_scale, step, seed);
output, input, tokens_position, embeddings, pos_embeddings, dropout_mask,
seq_len, embedding_dim, padding_idx, dropout_ratio, emb_scale, step,
seed);
cudaFree(tokens_position);
}

template <>
Expand All @@ -198,6 +264,18 @@ void launch_lookup_scale_pos_dropout<__half>(
const __half *pos_embeddings, uint8_t *dropout_mask, int batch_size,
int seq_len, int embedding_dim, int padding_idx, float dropout_ratio,
int step, cudaStream_t &stream) {
int *tokens_position;
cudaMalloc(&tokens_position, batch_size * seq_len * sizeof(int));
int block_x_tokens = min(seq_len, MAX_THREADS);
int block_y_tokens = (seq_len + block_x_tokens - 1) / block_x_tokens;
int dynamic_temp_size = batch_size * seq_len * 2 * sizeof(int);
dim3 p_grid_dim(batch_size, 1);
dim3 p_block_dim(block_x_tokens, block_y_tokens);

// compute the token position indices in a separate kernel, because
// sequence-level synchronization is required
get_tokens_position<<<p_grid_dim, p_block_dim, dynamic_temp_size, stream>>>(
tokens_position, input, seq_len, padding_idx);

float emb_scale = sqrt(embedding_dim);
embedding_dim >>= 3;

Expand All @@ -209,10 +287,11 @@ void launch_lookup_scale_pos_dropout<__half>(
int seed = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();

lookup_scale_pos_dropout<__half><<<grid_dim, block_dim, 0, stream>>>(
output, input, embeddings, pos_embeddings, dropout_mask, seq_len,
embedding_dim, padding_idx, dropout_ratio, emb_scale, step, seed);
output, input, tokens_position, embeddings, pos_embeddings, dropout_mask,
seq_len, embedding_dim, padding_idx, dropout_ratio, emb_scale, step,
seed);
cudaFree(tokens_position);
}

/**
Expand Down

0 comments on commit 4224b87

Please sign in to comment.