From 5dc12650b413f16faf82f9d2c4a1fbc6c232fd88 Mon Sep 17 00:00:00 2001
From: nomadlx
Date: Tue, 3 Aug 2021 18:49:56 +0800
Subject: [PATCH] undo unnecessary modification and limit the maximum value of
 positional parameters

---
 examples/training/fairseq/fs_modules/ls_transformer.py  | 5 ++++-
 lightseq/training/csrc/kernels/embedding_kernels.cu     | 6 ++----
 .../training/ops/pytorch/transformer_embedding_layer.py | 2 --
 setup.py                                                 | 6 +++---
 4 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/examples/training/fairseq/fs_modules/ls_transformer.py b/examples/training/fairseq/fs_modules/ls_transformer.py
index 1412d8662..3003e64de 100644
--- a/examples/training/fairseq/fs_modules/ls_transformer.py
+++ b/examples/training/fairseq/fs_modules/ls_transformer.py
@@ -146,7 +146,6 @@ def build_model(cls, args, task):
 
     @classmethod
     def build_embedding(cls, args, dictionary, embed_dim, max_positions, **kwargs):
-        max_positions = max_positions + dictionary.pad() + 1
         config = LSTransformerEmbeddingLayer.get_config(
             vocab_size=len(dictionary),
             embedding_dim=embed_dim,
@@ -356,6 +355,10 @@ def tiny_architecture(args):
 
 @register_model_architecture("ls_transformer", "ls_transformer")
 def base_architecture(args):
+    # specify a small value (300) that meets the needs of most NLP datasets, to avoid OOM errors
+    args.max_source_positions = min(MAX_SEQ_LENGTH, getattr(args, "max_source_positions", MAX_SEQ_LENGTH))
+    args.max_target_positions = min(MAX_SEQ_LENGTH, getattr(args, "max_target_positions", MAX_SEQ_LENGTH))
+
     args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
     args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
     args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
diff --git a/lightseq/training/csrc/kernels/embedding_kernels.cu b/lightseq/training/csrc/kernels/embedding_kernels.cu
index c6ccc69fe..604d4b378 100644
--- a/lightseq/training/csrc/kernels/embedding_kernels.cu
+++ b/lightseq/training/csrc/kernels/embedding_kernels.cu
@@ -130,8 +130,7 @@ __global__ void lookup_scale_pos_dropout(
     int offset = i - target_pos * embedding_dim;
     float4 e4 = embeddings4[tid * embedding_dim + offset];
     // step is non-zero only in inference
-    // position numbers begin at padding_idx+1, same to fairseq implementation
-    float4 pe4 = pos_embeddings4[(token_pos_id + step + padding_idx + 1) * embedding_dim + offset];
+    float4 pe4 = pos_embeddings4[(token_pos_id + step) * embedding_dim + offset];
     float4 res4;
     res4.x = (emb_scale * e4.x + pe4.x) * scale * m[0];
     res4.y = (emb_scale * e4.y + pe4.y) * scale * m[1];
@@ -196,8 +195,7 @@ __global__ void lookup_scale_pos_dropout<__half>(
     int offset = i - target_pos * embedding_dim;
     float4 e4 = embeddings4[tid * embedding_dim + offset];
     // step is non-zero only in inference
-    // position numbers begin at padding_idx+1, same to fairseq
-    float4 pe4 = pos_embeddings4[(token_pos_id + step + padding_idx + 1) * embedding_dim + offset];
+    float4 pe4 = pos_embeddings4[(token_pos_id + step) * embedding_dim + offset];
     float4 res4;
 
     __half2 *e_h2 = reinterpret_cast<__half2 *>(&e4);
diff --git a/lightseq/training/ops/pytorch/transformer_embedding_layer.py b/lightseq/training/ops/pytorch/transformer_embedding_layer.py
index 603e14943..e80bfae3e 100644
--- a/lightseq/training/ops/pytorch/transformer_embedding_layer.py
+++ b/lightseq/training/ops/pytorch/transformer_embedding_layer.py
@@ -149,8 +149,6 @@ def get_pos_embedding(self, num_pos_embeddings):
         )
         if self.config.embedding_dim % 2 == 1:
             emb = torch.cat([emb, torch.zeros(num_pos_embeddings, 1)], dim=1)
-        if self.config.padding_idx is not None:
-            emb[self.config.padding_idx, :] = 0
         return emb
 
     def __assign_layer_weight_grad(self):
diff --git a/setup.py b/setup.py
index 011287be9..8538e552d 100644
--- a/setup.py
+++ b/setup.py
@@ -116,11 +116,11 @@ def build_extension(self, ext):
     include_package_data=True,
     entry_points={
         "console_scripts": [
-            "lightseq-train = examples.training.fairseq.fs_cli."
+            "lightseq-train = lightseq.examples.training.fairseq.fs_cli."
             "lightseq_fairseq_train_cli:ls_cli_main",
-            "lightseq-generate = examples.training.fairseq.fs_cli."
+            "lightseq-generate = lightseq.examples.training.fairseq.fs_cli."
             "lightseq_fairseq_generate_cli:ls_cli_main",
-            "lightseq-validate = examples.training.fairseq.fs_cli."
+            "lightseq-validate = lightseq.examples.training.fairseq.fs_cli."
             "lightseq_fairseq_validate_cli:ls_cli_main",
         ],
     },
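
Note (illustration only, not part of the patch): a minimal Python sketch of the capping behavior that base_architecture() gains above. It assumes MAX_SEQ_LENGTH is the 300 mentioned in the patch comment; cap_positions and the SimpleNamespace object are hypothetical stand-ins for fairseq's parsed CLI arguments.

    # Sketch only: MAX_SEQ_LENGTH = 300 mirrors the value cited in the patch comment.
    from types import SimpleNamespace

    MAX_SEQ_LENGTH = 300

    def cap_positions(args):
        # Clamp user-supplied limits so the positional-embedding table never
        # exceeds MAX_SEQ_LENGTH rows; missing attributes fall back to the cap.
        args.max_source_positions = min(
            MAX_SEQ_LENGTH, getattr(args, "max_source_positions", MAX_SEQ_LENGTH)
        )
        args.max_target_positions = min(
            MAX_SEQ_LENGTH, getattr(args, "max_target_positions", MAX_SEQ_LENGTH)
        )
        return args

    args = cap_positions(SimpleNamespace(max_source_positions=1024))
    assert args.max_source_positions == 300  # oversized request is capped
    assert args.max_target_positions == 300  # unspecified value defaults to the cap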