This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Failed to train Mask-Predict with a larger model/hidden dimension #7

Open
alphadl opened this issue Jan 21, 2020 · 9 comments

@alphadl

alphadl commented Jan 21, 2020

Elegant work! In addition to training a transformer_base-scale model, I am also trying to train a large model (e.g., 1024 model dim. & 4096 hidden dim.) so that I can fine-tune Mask-Predict with XLM.

However, when I simply change the dimensions and keep the other arguments fixed, training fails: the ppl keeps increasing. Can you give me some advice?

Below is my training command:

python train.py data-bin/xlm_pretained-wmt14.en-de --arch bert_transformer_seq2seq --share-all-embeddings --criterion label_smoothed_length_cross_entropy --label-smoothing 0.1 --lr 5e-4 --warmup-init-lr 1e-7 --min-lr 1e-9 --lr-scheduler inverse_sqrt --warmup-updates 10000 --optimizer adam --adam-betas '(0.9,0.999)' --adam-eps 1e-6 --task translation_self --max-tokens 11000 --weight-decay 0.01 --dropout 0.3 --encoder-layers 6 --encoder-embed-dim 1024 --decoder-layers 6 --decoder-embed-dim 1024 --encoder-attention-heads 8 --decoder-attention-heads 8 --max-source-positions 10000 --max-target-positions 10000 --max-update 300000 --seed 0 --save-dir ${model_dir} --update-freq 3 --ddp-backend=no_c10d --fp16 --keep-last-epochs 10

and the following is one line from the training log:

| epoch 012:  74%|▋| 814/1099 [24:26<08:28,  1.79s/it, loss=12.243, nll_loss=11.121, ppl=2226.58, wps=33332, ups=1, wpb=60068.756, bsz=4060.299, num_updates=12894, lr=0.000440328, gnorm=0.341, clip=0.000, oom=0.000, loss_scale=0.250, wall=23856, train_wall=20393, length_loss=6.6472] 

BTW, because I reused the XLM vocabulary, the vocab size of the larger Mask-Predict model is over 60k.

Namespace(adam_betas='(0.9,0.999)', adam_eps=1e-06, adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0, arch='bert_transformer_seq2seq', attention_dropout=0.0, best_checkpoint_metric='loss', bilm_add_bos=False, bilm_attention_dropout=0.0, bilm_mask_last_state=False, bilm_model_dropout=0.1, bilm_relu_dropout=0.0, bucket_cap_mb=25, clip_norm=25, cpu=False, criterion='label_smoothed_length_cross_entropy', curriculum=0, data=['data-bin/xlm_pretained-wmt14.en-de'], dataset_impl=None, ddp_backend='no_c10d', decoder_attention_heads=8, decoder_embed_dim=1024, decoder_embed_path=None, decoder_embed_scale=None, decoder_ffn_embed_dim=4096, decoder_input_dim=1024, decoder_layers=6, decoder_learned_pos=False, decoder_normalize_before=False, decoder_output_dim=1024, device_id=0, disable_validation=False, distributed_backend='nccl', distributed_init_method='tcp://localhost:10859', distributed_no_spawn=False, distributed_port=-1, distributed_rank=0, distributed_world_size=4, dropout=0.3, dynamic_length=False, embedding_only=False, encoder_attention_heads=8, encoder_embed_dim=1024, encoder_embed_path=None, encoder_embed_scale=None, encoder_ffn_embed_dim=4096, encoder_layers=6, encoder_learned_pos=False, encoder_normalize_before=False, find_unused_parameters=False, fix_batches_to_gpus=False, fp16=True, fp16_init_scale=128, fp16_scale_tolerance=0.0, fp16_scale_window=None, keep_interval_updates=-1, keep_last_epochs=10, label_smoothing=0.1, left_pad_source='True', left_pad_target='False', log_format=None, log_interval=1000, lr=[0.0005], lr_scheduler='inverse_sqrt', mask_range=False, max_epoch=0, max_sentences=None, max_sentences_valid=None, max_source_positions=10000, max_target_positions=10000, max_tokens=11000, max_tokens_valid=11000, max_update=500000, maximize_best_checkpoint_metric=False, memory_efficient_fp16=False, min_loss_scale=0.0001, min_lr=1e-09, no_dec_token_positional_embeddings=False, no_enc_token_positional_embeddings=False, no_epoch_checkpoints=False, no_last_checkpoints=False, no_progress_bar=False, no_save=False, no_save_optimizer_state=False, num_workers=0, optimizer='adam', optimizer_overrides='{}', raw_text=False, relu_dropout=0.0, required_batch_size_multiple=8, reset_dataloader=False, reset_lr_scheduler=False, reset_meters=False, reset_optimizer=False, restore_file='checkpoint_last.pt', save_dir='./distill_model_from_scratch_1024_xlm', save_interval=1, save_interval_updates=0, seed=0, self_target=False, sentence_avg=False, share_all_embeddings=True, share_decoder_input_output_embed=False, skip_invalid_size_inputs_valid_test=False, source_lang=None, target_lang=None, task='translation_self', tbmf_wrapper=False, tensorboard_logdir='', threshold_loss_scale=None, train_subset='train', update_freq=[3], upsample_primary=1, use_bmuf=False, user_dir=None, valid_subset='valid', validate_interval=1, warmup_init_lr=1e-07, warmup_updates=10000, weight_decay=0.01)
| [en] dictionary: 60192 types
| [de] dictionary: 60192 types
| data-bin/xlm_pretained-wmt14.en-de valid 3000 examples
Transformer_nonautoregressive(
  (encoder): TransformerEncoder(
    (embed_tokens): Embedding(60192, 1024, padding_idx=1)
    (embed_positions): LearnedPositionalEmbedding(10002, 1024, padding_idx=1)
    (embed_lengths): Embedding(10000, 1024)
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (layer_norms): ModuleList(
          (0): BertLayerNorm()
          (1): BertLayerNorm()
        )
      )
      (1)-(5): TransformerEncoderLayer(...) (same as (0), omitted)
    )
  )
  (decoder): SelfTransformerDecoder(
    (embed_tokens): Embedding(60192, 1024, padding_idx=1)
    (embed_positions): LearnedPositionalEmbedding(10002, 1024, padding_idx=1)
    (layers): ModuleList(
      (0): TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (self_attn_layer_norm): BertLayerNorm()
        (encoder_attn): MultiheadAttention(
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (encoder_attn_layer_norm): BertLayerNorm()
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (final_layer_norm): BertLayerNorm()
      )
      (1)-(5): TransformerDecoderLayer(...) (same as (0), omitted)
    )
  )
)
@alphadl

alphadl commented Jan 21, 2020

@yinhanliu Hoping for your advice 😉

@jungokasai

jungokasai commented Jan 21, 2020

Have you tried using the hyperparameters for transformer big? https://github.com/pytorch/fairseq/blob/master/examples/scaling_nmt/README.md#3-train-a-model. I think I ran into a similar problem at some point, and switching to the large-model hyperparameters from Ott et al. (2018) got it to work.

@Marjan-GH

Also, don't forget to preprocess your data with the code of this branch.
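(For readers following along, a minimal preprocessing sketch, assuming the standard fairseq preprocess.py entry point, WMT'14 En-De BPE files, and reuse of the XLM vocabulary via --srcdict/--tgtdict as described in the first post; the file paths and dictionary name below are placeholders, and the branch's README is the authority for any extra flags:)

# Hypothetical paths; the XLM dictionary file is a placeholder.
python preprocess.py --source-lang en --target-lang de \
    --trainpref wmt14_ende/train.bpe --validpref wmt14_ende/valid.bpe --testpref wmt14_ende/test.bpe \
    --srcdict xlm_vocab/dict.txt --tgtdict xlm_vocab/dict.txt \
    --destdir data-bin/xlm_pretained-wmt14.en-de --workers 16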

@alphadl

alphadl commented Jan 22, 2020

@jungokasai @Marjan-GH Thanks! I adjusted the original command to:

python train.py data-bin/xlm_pretained-wmt14.en-de --arch bert_transformer_seq2seq --share-all-embeddings --criterion label_smoothed_length_cross_entropy --label-smoothing 0.1 --lr 5e-4 --warmup-init-lr 1e-7 --min-lr 1e-9 --lr-scheduler inverse_sqrt --warmup-updates 4000 --optimizer adam --adam-betas '(0.9,0.98)' --task translation_self --max-tokens 11000 --weight-decay 0.01 --dropout 0.3 --encoder-layers 6 --encoder-embed-dim 1024 --decoder-layers 6 --decoder-embed-dim 1024 --encoder-attention-heads 8 --decoder-attention-heads 8 --max-source-positions 10000 --max-target-positions 10000 --max-update 500000 --seed 0 --save-dir ${model_dir} --update-freq 3 --ddp-backend=no_c10d --fp16 --keep-last-epochs 10

This seems to work for me. Here is one line of the training log; the ppl is converging at a reasonable rate:

| epoch 006:  86%|▊| 944/1099 [28:55<04:45,  1.84s/it, loss=3.986, nll_loss=1.915, ppl=3.77, wps=32894, ups=1, wpb=60193.316, bsz=4095.215, num_updates=6429, lr=0.000394392, gnorm=0.705, clip=0.000, oom=0.000, loss_scale=0.125, wall=11957, train_wall=10246, length_loss=3.72187]

@alphadl

alphadl commented Mar 1, 2020

@yinhanliu @omerlevy @Marjan-GH

Hi, dear authors, I have trained the large-scale Mask-Predict model (hidden size: 1024/4096, vocab size: 60k+); the parameter count is ~270M. Since the number of parameters increases, the translation quality should, as expected, be better than that of the base-scale Mask-Predict (512/2048). However, the BLEU score of the large-scale model on the EN-DE newstest2014 set is only ~26.

I'm pretty sure the model has converged; some indicators are shown below.

The loss of the latest large-scale checkpoint, which I used for evaluation, is as follows:

loss=2.915; nll_loss=0.833;  ppl=1.78; length_loss=2.88968; lr=0.000119598

which looks clearly better than the base-scale Mask-Predict, with which I reproduced your result on the same EN-DE dataset (BLEU reaches ~27). The loss of that base-scale model is:

loss=3.136; nll_loss=1.146; ppl=2.21; length_loss=3.04; lr=0.000117369

So I am wondering whether the Mask-Predict model only works at the base (512/2048) scale and does not work under the large-scale setting?

Looking forward to your reply

Best

@jungokasai

jungokasai commented Mar 2, 2020

That seems a bit strange. The perplexity and length loss in validation are smaller, so I would expect the large transformer to be at least as good as the base one in BLEU as well; it does not look like a training issue. Just as a sanity check, could you check the BLEU and loss on the validation data with both base and large? Roughly speaking, there should be a correlation between BLEU and the loss; if there isn't, there might be something wrong with the inference. Otherwise it might be overfitting? I wouldn't have expected that to happen with dropout 0.3, though. Also, please make sure you are distilling from the same large autoregressive transformer.
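(A sketch of that sanity check, assuming the standard fairseq generate.py entry point and flags; the repo's own decoding script and its Mask-Predict-specific options, such as the number of decoding iterations, should be substituted per its README, so treat this as schematic:)

# Hypothetical: score the same checkpoint on both the validation and the test split.
for subset in valid test; do
    python generate.py data-bin/xlm_pretained-wmt14.en-de \
        --path ${model_dir}/checkpoint_best.pt --task translation_self \
        --gen-subset $subset --remove-bpe --max-sentences 20
done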

@alphadl

alphadl commented Mar 2, 2020


Hi @jungokasai,

The BLEU scores on the validation set using the best single checkpoint are:

large scale model↓

BLEU = 15.74, 38.1/19.3/11.3/7.4 (BP=1.000, ratio=1.946, syslen=62394, reflen=32064)

base scale model↓

BLEU = 15.19, 37.3/18.7/11.0/7.0 (BP=1.000, ratio=1.935, syslen=62578, reflen=32332)

There does exist a positive correlation between the validation BLEU and the loss.

However, the BLEU scores on the test set with the best single checkpoint are:

large scale model↓

BLEU = 25.77, 57.7/31.5/19.5/12.5 (BP=1.000, ratio=1.008, syslen=65745, reflen=65214)

base scale model↓

BLEU = 26.81, 58.9/32.7/20.4/13.2 (BP=1.000, ratio=1.005, syslen=64820, reflen=64496)

This phenomenon is strange. Does it mean that the Mask-Predict architecture is simply not suitable for the large scale?

BTW, all my distilled data is derived from a powerful pretrained big AT model; see that issue response.
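(For completeness, a sketch of how such distilled data is usually produced via sequence-level knowledge distillation: decode the training source side with the autoregressive teacher and re-binarize its hypotheses as the new targets. The teacher checkpoint path and output handling below are placeholders, not the exact pipeline used here:)

# Hypothetical: generate beam-search hypotheses from the big AT teacher over the training set.
python generate.py data-bin/wmt14.en-de --path at_big_teacher/checkpoint_best.pt \
    --gen-subset train --beam 5 --remove-bpe --max-tokens 8000 > train.distill.log
# Extract the H- (hypothesis) lines, pair them with the original sources,
# and rerun the preprocessing step above to build the distilled data-bin.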

@omerlevy

omerlevy commented Mar 2, 2020

Hi Liam,
When increasing the model size, you usually need to retune the optimization hyperparameters (e.g., learning rate, dropout). I would start with the recommended values for Transformer-Large and tweak them from there.
Hope this helps!
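(One way to act on this, sketched as a small sweep over learning rate and dropout around the Transformer-Large values; the grid values and the ${model_dir} naming are only illustrative, and COMMON_FLAGS merely collects the unchanged flags from the working command above:)

# Hypothetical sweep; COMMON_FLAGS holds the unchanged flags from the working command above.
COMMON_FLAGS="--arch bert_transformer_seq2seq --share-all-embeddings --criterion label_smoothed_length_cross_entropy --label-smoothing 0.1 --lr-scheduler inverse_sqrt --warmup-init-lr 1e-7 --min-lr 1e-9 --warmup-updates 4000 --optimizer adam --task translation_self --max-tokens 11000 --weight-decay 0.01 --encoder-layers 6 --encoder-embed-dim 1024 --decoder-layers 6 --decoder-embed-dim 1024 --encoder-attention-heads 8 --decoder-attention-heads 8 --max-source-positions 10000 --max-target-positions 10000 --max-update 500000 --seed 0 --update-freq 3 --ddp-backend=no_c10d --fp16 --keep-last-epochs 10"
for lr in 3e-4 5e-4 7e-4; do
  for dp in 0.1 0.2 0.3; do
    python train.py data-bin/xlm_pretained-wmt14.en-de $COMMON_FLAGS --adam-betas '(0.9,0.98)' \
        --lr $lr --dropout $dp --save-dir ${model_dir}_lr${lr}_dp${dp}
  done
done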

@alphadl

alphadl commented Mar 2, 2020

@omerlevy
Hi Levy, thanks for your prompt reply~

Because of my relatively limited computing resources, the experiment took a long time. Looking forward to your results! This will be helpful for researchers who follow this paper. Thank you!
