Commit

undo unnecessary modification and limit the maximum value of positional parameters
nomadlx committed Aug 3, 2021
1 parent 63b78f5 commit 5dc1265
Showing 4 changed files with 9 additions and 10 deletions.
5 changes: 4 additions & 1 deletion examples/training/fairseq/fs_modules/ls_transformer.py
@@ -146,7 +146,6 @@ def build_model(cls, args, task):

@classmethod
def build_embedding(cls, args, dictionary, embed_dim, max_positions, **kwargs):
- max_positions = max_positions + dictionary.pad() + 1
config = LSTransformerEmbeddingLayer.get_config(
vocab_size=len(dictionary),
embedding_dim=embed_dim,
@@ -356,6 +355,10 @@ def tiny_architecture(args):

@register_model_architecture("ls_transformer", "ls_transformer")
def base_architecture(args):
+ # specify a small value (300) that meets the needs of most NLP datasets, to avoid OOM errors
+ args.max_source_positions = min(MAX_SEQ_LENGTH, getattr(args, "max_source_positions", MAX_SEQ_LENGTH))
+ args.max_target_positions = min(MAX_SEQ_LENGTH, getattr(args, "max_target_positions", MAX_SEQ_LENGTH))

args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
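The effect of the base_architecture change above, in isolation: the two position limits now fall back to MAX_SEQ_LENGTH when unset and are clamped to it when set, and build_embedding no longer inflates max_positions by dictionary.pad() + 1. A minimal sketch of the clamping, assuming MAX_SEQ_LENGTH is the module-level constant referenced by ls_transformer.py (300 per the comment); the helper name cap_position_args and the argparse.Namespace stand-in for fairseq's args object are illustrative:

import argparse

MAX_SEQ_LENGTH = 300  # assumption: the constant imported by ls_transformer.py

def cap_position_args(args):
    # Fall back to MAX_SEQ_LENGTH when the flag is missing, then clamp larger
    # user-supplied values so the positional embedding table stays small.
    args.max_source_positions = min(
        MAX_SEQ_LENGTH, getattr(args, "max_source_positions", MAX_SEQ_LENGTH)
    )
    args.max_target_positions = min(
        MAX_SEQ_LENGTH, getattr(args, "max_target_positions", MAX_SEQ_LENGTH)
    )
    return args

args = cap_position_args(argparse.Namespace(max_source_positions=1024))
print(args.max_source_positions, args.max_target_positions)  # 300 300
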
6 changes: 2 additions & 4 deletions lightseq/training/csrc/kernels/embedding_kernels.cu
@@ -130,8 +130,7 @@ __global__ void lookup_scale_pos_dropout<float>(
int offset = i - target_pos * embedding_dim;
float4 e4 = embeddings4[tid * embedding_dim + offset];
// step is non-zero only in inference
- // position numbers begin at padding_idx+1, same to fairseq implementation
- float4 pe4 = pos_embeddings4[(token_pos_id + step + padding_idx + 1) * embedding_dim + offset];
+ float4 pe4 = pos_embeddings4[(token_pos_id + step) * embedding_dim + offset];
float4 res4;
res4.x = (emb_scale * e4.x + pe4.x) * scale * m[0];
res4.y = (emb_scale * e4.y + pe4.y) * scale * m[1];
@@ -196,8 +195,7 @@ __global__ void lookup_scale_pos_dropout<__half>(
int offset = i - target_pos * embedding_dim;
float4 e4 = embeddings4[tid * embedding_dim + offset];
// step is non-zero only in inference
- // position numbers begin at padding_idx+1, same to fairseq
- float4 pe4 = pos_embeddings4[(token_pos_id + step + padding_idx + 1) * embedding_dim + offset];
+ float4 pe4 = pos_embeddings4[(token_pos_id + step) * embedding_dim + offset];
float4 res4;

__half2 *e_h2 = reinterpret_cast<__half2 *>(&e4);
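What the two kernel edits change, reduced to plain array indexing: the positional row is now read at token_pos_id + step directly, instead of token_pos_id + step + padding_idx + 1. A NumPy sketch of that lookup (dropout and the vectorized float4 loads are omitted; the table sizes, the sqrt(d) scale, and the function name are illustrative assumptions, not LightSeq's API):

import numpy as np

vocab_size, max_positions, embedding_dim = 1000, 300, 8  # illustrative sizes
embeddings = np.random.rand(vocab_size, embedding_dim).astype(np.float32)
pos_embeddings = np.random.rand(max_positions, embedding_dim).astype(np.float32)
emb_scale = np.sqrt(embedding_dim)  # assumption: sqrt(d) scaling as in the standard Transformer

def lookup_scale_pos(token_id, token_pos_id, step=0):
    # step is non-zero only in inference (incremental decoding).
    # The positional row is addressed by token_pos_id + step alone;
    # the former padding_idx + 1 shift is gone.
    return emb_scale * embeddings[token_id] + pos_embeddings[token_pos_id + step]

print(lookup_scale_pos(token_id=5, token_pos_id=0, step=2).shape)  # (8,)

With the padding shift removed, positions index the table from row 0, which is consistent with build_embedding above no longer reserving dictionary.pad() + 1 extra rows.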
2 changes: 0 additions & 2 deletions lightseq/training/ops/pytorch/transformer_embedding_layer.py
@@ -149,8 +149,6 @@ def get_pos_embedding(self, num_pos_embeddings):
)
if self.config.embedding_dim % 2 == 1:
emb = torch.cat([emb, torch.zeros(num_pos_embeddings, 1)], dim=1)
- if self.config.padding_idx is not None:
-     emb[self.config.padding_idx, :] = 0
return emb

def __assign_layer_weight_grad(self):
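For context, get_pos_embedding builds a sinusoidal table; the deleted branch zeroed out the row at padding_idx, and after this commit every row keeps its sinusoidal values. A sketch of the standard construction, assuming the usual fairseq-style formula (LightSeq's exact dtype and device handling may differ; the function name here is illustrative):

import math
import torch

def sinusoidal_pos_embedding(num_pos_embeddings, embedding_dim):
    # sin on the first half of the channels, cos on the second half
    half_dim = embedding_dim // 2
    scale = math.log(10000) / (half_dim - 1)
    inv_freq = torch.exp(torch.arange(half_dim, dtype=torch.float) * -scale)
    pos = torch.arange(num_pos_embeddings, dtype=torch.float).unsqueeze(1)
    emb = torch.cat([torch.sin(pos * inv_freq), torch.cos(pos * inv_freq)], dim=1)
    if embedding_dim % 2 == 1:
        # odd dimension: pad with a zero column, as in the method above
        emb = torch.cat([emb, torch.zeros(num_pos_embeddings, 1)], dim=1)
    # no special-casing of padding_idx anymore: its row is a normal position
    return emb

print(sinusoidal_pos_embedding(300, 512).shape)  # torch.Size([300, 512])
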
6 changes: 3 additions & 3 deletions setup.py
@@ -116,11 +116,11 @@ def build_extension(self, ext):
include_package_data=True,
entry_points={
"console_scripts": [
"lightseq-train = examples.training.fairseq.fs_cli."
"lightseq-train = lightseq.examples.training.fairseq.fs_cli."
"lightseq_fairseq_train_cli:ls_cli_main",
"lightseq-generate = examples.training.fairseq.fs_cli."
"lightseq-generate = lightseq.examples.training.fairseq.fs_cli."
"lightseq_fairseq_generate_cli:ls_cli_main",
"lightseq-validate = examples.training.fairseq.fs_cli."
"lightseq-validate = lightseq.examples.training.fairseq.fs_cli."
"lightseq_fairseq_validate_cli:ls_cli_main",
],
},
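The setup.py edit only repoints the three console scripts, presumably so the entry points resolve against the installed lightseq package rather than a bare examples package. For reference, a minimal, hypothetical setup() showing the console_scripts convention those entries follow, command = importable.module.path:callable (the package name and version below are placeholders, not LightSeq's real metadata):

from setuptools import setup, find_packages

setup(
    name="example-pkg",  # placeholder
    version="0.0.1",
    packages=find_packages(),
    entry_points={
        "console_scripts": [
            # pip install generates a `lightseq-train` executable that imports
            # the module left of the colon and calls ls_cli_main()
            "lightseq-train = lightseq.examples.training.fairseq.fs_cli."
            "lightseq_fairseq_train_cli:ls_cli_main",
        ],
    },
)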
