fix pos embedding index bug (#144)
* fix position embedding index bug for left-padded inputs

* add embedding unit tests covering left padding
nomadlx committed Aug 4, 2021
1 parent 79249ad commit 913b999
Showing 4 changed files with 136 additions and 27 deletions.
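Summary of the fix: the CUDA embedding kernel used to index the positional-embedding table with the raw sequence offset (seq_id + step), which is only correct when padding sits at the end of a sequence. With left padding, the first real token no longer sits at offset 0, so every token read a shifted position embedding. The patch counts only non-padding tokens instead, matching the corrected make_positions in tests/fairseq_layers.py below. A rough PyTorch contrast of the two behaviours (tensor values are illustrative only, padding_idx assumed to be 1):

    import torch

    x = torch.tensor([[1, 1, 5, 6, 7]])   # two pads, then three real tokens
    pad = 1

    # old behaviour: position = raw sequence offset, padding not considered
    old = (torch.cumsum(torch.ones_like(x), dim=1) - 1).long()
    # tensor([[0, 1, 2, 3, 4]]) -> the first real token wrongly gets position 2

    # new behaviour: count only the non-pad tokens
    mask = x.ne(pad).int()
    new = ((torch.cumsum(mask, dim=1).type_as(mask) - 1) * mask).long()
    # tensor([[0, 0, 0, 1, 2]]) -> the first real token correctly gets position 0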
10 changes: 9 additions & 1 deletion examples/training/fairseq/fs_modules/ls_transformer.py
@@ -150,7 +150,7 @@ def build_embedding(cls, args, dictionary, embed_dim, max_positions, **kwargs):
vocab_size=len(dictionary),
embedding_dim=embed_dim,
max_batch_tokens=args.max_tokens,
max_seq_len=MAX_SEQ_LENGTH, # FIXME later
max_seq_len=max_positions,
padding_idx=dictionary.pad(),
dropout=args.dropout,
fp16=args.fp16,
@@ -355,6 +355,14 @@ def tiny_architecture(args):

@register_model_architecture("ls_transformer", "ls_transformer")
def base_architecture(args):
# cap at MAX_SEQ_LENGTH (300), which meets the needs of most NLP datasets, to avoid OOM errors
args.max_source_positions = min(
MAX_SEQ_LENGTH, getattr(args, "max_source_positions", MAX_SEQ_LENGTH)
)
args.max_target_positions = min(
MAX_SEQ_LENGTH, getattr(args, "max_target_positions", MAX_SEQ_LENGTH)
)

args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
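The architecture hook above clamps max_source_positions/max_target_positions, presumably so the max_positions value consumed by build_embedding never exceeds the old hard-coded cap. A minimal sketch of the clamping, assuming MAX_SEQ_LENGTH is 300 as the comment above suggests and using a hypothetical args namespace:

    from argparse import Namespace

    MAX_SEQ_LENGTH = 300                          # assumed, per the comment above
    args = Namespace(max_source_positions=1024)   # hypothetical user override

    max_src = min(MAX_SEQ_LENGTH, getattr(args, "max_source_positions", MAX_SEQ_LENGTH))
    print(max_src)  # 300 -- larger requests are clamped to keep the position table small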
133 changes: 113 additions & 20 deletions lightseq/training/csrc/kernels/embedding_kernels.cu
@@ -3,6 +3,67 @@

#include "kernels.h"

/**
@brief: get_tokens_position
Compute each token's position within its sequence, counting only non-padding tokens, so position embeddings remain correct with left padding.
@thread
gridDim.x = batch_size
gridDim.y = 1
blockDim.x = min(seq_len, MAX_THREADS)
blockDim.y = 1
@param
output: [2, batch_size, seq_len]; two ping-pong buffers for the scan, the final positions are written to the first half
input: [batch_size, seq_len]
batch_size: the size of the current batch
seq_len: the sequence length of the current batch
padding_idx: padding index of the sentences (default: 2)
*/
__global__ void get_tokens_position(int *output, const int *input,
int batch_size, int seq_len,
int padding_idx) {
int batch_id = blockIdx.x;
int start_seq_id = threadIdx.x;
int threads = blockDim.x;

int batch_offset = batch_id * seq_len;
int temp_offset[2];
temp_offset[0] = 0;
temp_offset[1] = batch_size * seq_len;

int *temp = output;
int pout = 0, pin = 1;
int pout_idx, pin_idx, target_pos;

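// Flag each position with 1 if the token before it is a real (non-pad) token;
// an inclusive scan of these flags gives, for every position, the number of
// non-pad tokens preceding it, i.e. its position index with padding ignored.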
for (int seq_id = start_seq_id; seq_id < seq_len; seq_id += threads) {
target_pos = batch_offset + seq_id;
pout_idx = temp_offset[pout] + batch_offset + seq_id;
temp[pout_idx] =
(seq_id > 0 && input[target_pos - 1] != padding_idx) ? 1 : 0;
}
__syncthreads();
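// Hillis-Steele inclusive scan over the sequence, ping-ponging between the
// two halves of the output buffer after each stride.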
for (int stride = 1; stride < seq_len; stride *= 2) {
pout = 1 - pout;
pin = 1 - pout;
for (int seq_id = start_seq_id; seq_id < seq_len; seq_id += threads) {
pout_idx = temp_offset[pout] + batch_offset + seq_id;
pin_idx = temp_offset[pin] + batch_offset + seq_id;
if (seq_id >= stride)
temp[pout_idx] = temp[pin_idx] + temp[pin_idx - stride];
else
temp[pout_idx] = temp[pin_idx];
}
__syncthreads();
}

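// Write the scan result back into the first half of the buffer, which the
// caller reads as the final per-token positions.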
for (int seq_id = start_seq_id; seq_id < seq_len; seq_id += threads) {
target_pos = batch_offset + seq_id;
pout_idx = temp_offset[pout] + batch_offset + seq_id;
output[target_pos] = temp[pout_idx];
}
}

/**
@brief: lookup_scale_pos_dropout
forward of embedding layer in fairseq, including
@@ -15,8 +76,9 @@ blockDim.x = tokens_per_block
blockDim.y = min(embedding_dim, MAX_THREADS)
@param
input: [batch_size, seq_len]
output: [batch_size, seq_len, embedding_dim]
tokens_position: [batch_size, seq_len]
embeddings: [vocab_size, embedding_dim]
pos_embeddings: [max_seq_len, embedding_dim]
dropout_mask: [batch_size, seq_len, embedding_dim]
@@ -29,16 +91,17 @@ training and valid)
*/
template <typename T>
__global__ void lookup_scale_pos_dropout(
T *output, const int *input, const T *embeddings, const T *pos_embeddings,
uint8_t *dropout_mask, int seq_len, int embedding_dim, int padding_idx,
float dropout_ratio, float emb_scale, int step, int seed);
T *output, const int *input, const int *tokens_position,
const T *embeddings, const T *pos_embeddings, uint8_t *dropout_mask,
int seq_len, int embedding_dim, int padding_idx, float dropout_ratio,
float emb_scale, int step, int seed);

template <>
__global__ void lookup_scale_pos_dropout<float>(
float *output, const int *input, const float *embeddings,
const float *pos_embeddings, uint8_t *dropout_mask, int seq_len,
int embedding_dim, int padding_idx, float dropout_ratio, float emb_scale,
int step, int seed) {
float *output, const int *input, const int *tokens_position,
const float *embeddings, const float *pos_embeddings, uint8_t *dropout_mask,
int seq_len, int embedding_dim, int padding_idx, float dropout_ratio,
float emb_scale, int step, int seed) {
int batch_id = blockIdx.x;
int seq_id = blockIdx.y * blockDim.x + threadIdx.x;
if (seq_id >= seq_len) return;
@@ -48,6 +111,8 @@ __global__ void lookup_scale_pos_dropout<float>(
int end = (target_pos + 1) * embedding_dim;
int tid = input[target_pos];

int token_pos_id = tokens_position[target_pos];

float4 *output4 = reinterpret_cast<float4 *>(output);
const float4 *embeddings4 = reinterpret_cast<const float4 *>(embeddings);
const float4 *pos_embeddings4 =
@@ -81,7 +146,8 @@ __global__ void lookup_scale_pos_dropout<float>(
int offset = i - target_pos * embedding_dim;
float4 e4 = embeddings4[tid * embedding_dim + offset];
// step is non-zero only in inference
float4 pe4 = pos_embeddings4[(seq_id + step) * embedding_dim + offset];
float4 pe4 =
pos_embeddings4[(token_pos_id + step) * embedding_dim + offset];
float4 res4;
res4.x = (emb_scale * e4.x + pe4.x) * scale * m[0];
res4.y = (emb_scale * e4.y + pe4.y) * scale * m[1];
@@ -93,10 +159,10 @@ __global__ void lookup_scale_pos_dropout<float>(

template <>
__global__ void lookup_scale_pos_dropout<__half>(
__half *output, const int *input, const __half *embeddings,
const __half *pos_embeddings, uint8_t *dropout_mask, int seq_len,
int embedding_dim, int padding_idx, float dropout_ratio, float emb_scale,
int step, int seed) {
__half *output, const int *input, const int *tokens_position,
const __half *embeddings, const __half *pos_embeddings,
uint8_t *dropout_mask, int seq_len, int embedding_dim, int padding_idx,
float dropout_ratio, float emb_scale, int step, int seed) {
int batch_id = blockIdx.x;
int seq_id = blockIdx.y * blockDim.x + threadIdx.x;
if (seq_id >= seq_len) return;
@@ -106,6 +172,8 @@ __global__ void lookup_scale_pos_dropout<__half>(
int end = (target_pos + 1) * embedding_dim;
int tid = input[target_pos];

int token_pos_id = tokens_position[target_pos];

float4 *output4 = reinterpret_cast<float4 *>(output);
const float4 *embeddings4 = reinterpret_cast<const float4 *>(embeddings);
const float4 *pos_embeddings4 =
@@ -144,7 +212,8 @@ __global__ void lookup_scale_pos_dropout<__half>(
int offset = i - target_pos * embedding_dim;
float4 e4 = embeddings4[tid * embedding_dim + offset];
// step is non-zero only in inference
float4 pe4 = pos_embeddings4[(seq_id + step) * embedding_dim + offset];
float4 pe4 =
pos_embeddings4[(token_pos_id + step) * embedding_dim + offset];
float4 res4;

__half2 *e_h2 = reinterpret_cast<__half2 *>(&e4);
@@ -175,6 +244,16 @@ void launch_lookup_scale_pos_dropout<float>(
const float *pos_embeddings, uint8_t *dropout_mask, int batch_size,
int seq_len, int embedding_dim, int padding_idx, float dropout_ratio,
int step, cudaStream_t &stream) {
int *tokens_position;
cudaMalloc(&tokens_position, batch_size * seq_len * 2 * sizeof(int));
int p_threads = min(seq_len, MAX_THREADS);
dim3 p_grid_dim(batch_size, 1);
dim3 p_block_dim(p_threads, 1);
// compute the token position indices in a separate kernel first,
// because the scan requires synchronization across the whole sequence
get_tokens_position<<<p_grid_dim, p_block_dim, 0, stream>>>(
tokens_position, input, batch_size, seq_len, padding_idx);

float emb_scale = sqrt(embedding_dim);
embedding_dim >>= 2;

@@ -186,10 +265,12 @@ void launch_lookup_scale_pos_dropout<float>(
int seed = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();

lookup_scale_pos_dropout<float><<<grid_dim, block_dim, 0, stream>>>(
output, input, embeddings, pos_embeddings, dropout_mask, seq_len,
embedding_dim, padding_idx, dropout_ratio, emb_scale, step, seed);
output, input, tokens_position, embeddings, pos_embeddings, dropout_mask,
seq_len, embedding_dim, padding_idx, dropout_ratio, emb_scale, step,
seed);

cudaFree(tokens_position);
}

template <>
@@ -198,6 +279,16 @@ void launch_lookup_scale_pos_dropout<__half>(
const __half *pos_embeddings, uint8_t *dropout_mask, int batch_size,
int seq_len, int embedding_dim, int padding_idx, float dropout_ratio,
int step, cudaStream_t &stream) {
int *tokens_position;
cudaMalloc(&tokens_position, batch_size * seq_len * 2 * sizeof(int));
int p_threads = min(seq_len, MAX_THREADS);
dim3 p_grid_dim(batch_size, 1);
dim3 p_block_dim(p_threads, 1);
// compute the token position indices in a separate kernel first,
// because the scan requires synchronization across the whole sequence
get_tokens_position<<<p_grid_dim, p_block_dim, 0, stream>>>(
tokens_position, input, batch_size, seq_len, padding_idx);

float emb_scale = sqrt(embedding_dim);
embedding_dim >>= 3;

@@ -209,10 +300,12 @@ void launch_lookup_scale_pos_dropout<__half>(
int seed = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();

lookup_scale_pos_dropout<__half><<<grid_dim, block_dim, 0, stream>>>(
output, input, embeddings, pos_embeddings, dropout_mask, seq_len,
embedding_dim, padding_idx, dropout_ratio, emb_scale, step, seed);
output, input, tokens_position, embeddings, pos_embeddings, dropout_mask,
seq_len, embedding_dim, padding_idx, dropout_ratio, emb_scale, step,
seed);

cudaFree(tokens_position);
}

/**
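For reference, the position indices produced by get_tokens_position can be reproduced on the host with a simple prefix sum; the kernel computes the same thing in parallel, one block per sequence. A pure-Python sketch (shapes and the example padding_idx are assumptions for illustration):

    def get_tokens_position_ref(input_ids, padding_idx):
        # input_ids: list of token-id rows, one row per batch element
        positions = []
        for seq in input_ids:
            # flag[i] = 1 if the token before position i is a real (non-pad) token
            flags = [0] + [int(t != padding_idx) for t in seq[:-1]]
            running, row = 0, []
            for f in flags:               # inclusive prefix sum of the flags
                running += f
                row.append(running)
            positions.append(row)
        return positions

    # left-padded row, padding_idx assumed to be 1:
    print(get_tokens_position_ref([[1, 1, 5, 6, 7]], padding_idx=1))
    # [[0, 0, 0, 1, 2]] -- values at the pad slots themselves are incidental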
8 changes: 4 additions & 4 deletions tests/fairseq_layers.py
@@ -649,9 +649,9 @@ def get_embedding(
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
return emb

def make_positions(self, tensor):
mask = torch.ones_like(tensor)
return (torch.cumsum(mask, dim=1).type_as(mask) - 1).long()
def make_positions(self, tensor, padding_idx):
mask = tensor.ne(padding_idx).int()
return ((torch.cumsum(mask, dim=1).type_as(mask) - 1) * mask).long()

def forward(
self,
@@ -662,7 +662,7 @@ ):
):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input.size(0), input.size(1)
positions = self.make_positions(input)
positions = self.make_positions(input, self.padding_idx)
mask = (
torch.ne(input, self.padding_idx)
.unsqueeze(2)
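As a quick check, the updated make_positions gives every real token a position counted from 0 regardless of which side the padding is on (hypothetical batch, padding_idx assumed to be 1):

    import torch

    pad = 1
    batch = torch.tensor([
        [pad, pad, 11, 12, 13],   # left-padded row
        [21, 22, 23, pad, pad],   # right-padded row
    ])
    mask = batch.ne(pad).int()
    positions = ((torch.cumsum(mask, dim=1).type_as(mask) - 1) * mask).long()
    # tensor([[0, 0, 0, 1, 2],
    #         [0, 1, 2, 0, 0]])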
12 changes: 10 additions & 2 deletions tests/test_ls_ops.py
@@ -785,7 +785,11 @@ def test_embedding_layer_forward():
# TODO: cannot generate PAD in the middle of the sentences.
config = ls_emb_config_fp16
input = kt.randint(config.padding_idx + 1, config.vocab_size, (batch_size, seq_len))
input = input * (1 - padding_mask) + config.padding_idx * padding_mask
pad_left = random.choice([True, False])
if pad_left:
input = input * padding_mask + config.padding_idx * (1 - padding_mask)
else:
input = input * (1 - padding_mask) + config.padding_idx * padding_mask

if kt.dtype == torch.float:
custom_layer = custom_emb_layer_fp32
@@ -817,7 +821,11 @@ def test_embedding_layer_backward():
padding_mask = kt.attn_mask(batch_size, seq_len, dtype=torch.int)
config = ls_emb_config_fp16
input = kt.randint(config.padding_idx + 1, config.vocab_size, (batch_size, seq_len))
input = input * (1 - padding_mask) + config.padding_idx * padding_mask
pad_left = random.choice([True, False])
if pad_left:
input = input * padding_mask + config.padding_idx * (1 - padding_mask)
else:
input = input * (1 - padding_mask) + config.padding_idx * padding_mask

if kt.dtype == torch.float:
custom_layer = custom_emb_layer_fp32
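For anyone reproducing the test outside this harness, a left-padded batch can also be built directly in PyTorch; kt, the configs, and the exact mask convention used above belong to LightSeq's test utilities and are not reproduced here (a minimal sketch with made-up sizes):

    import torch

    vocab_size, padding_idx = 100, 1
    batch_size, seq_len = 2, 6
    lengths = torch.tensor([4, 6])          # number of real tokens in each row

    tokens = torch.randint(padding_idx + 1, vocab_size, (batch_size, seq_len))
    # a slot holds a real token iff it falls in the last `lengths[i]` positions of row i
    keep = torch.arange(seq_len).unsqueeze(0) >= (seq_len - lengths).unsqueeze(1)
    left_padded = torch.where(keep, tokens, torch.full_like(tokens, padding_idx))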
