issues with pretrain mBART models #2120

Open
fansiawang opened this issue May 11, 2020 · 35 comments
@fansiawang

❓ Questions and Help

Thanks for releasing the mBART models! Referring to #1758, I reproduced the results, which are close to those reported in the paper. I then used my own data (Japanese-Korean) to fine-tune the pre-trained mbart.cc25 model. Compared with the bilingual baseline (6.01 BLEU-SBP), fine-tuning improves by 2.64 points (9.65 BLEU-SBP).

Although I got good results, the absolute quality is still relatively low. Because the pre-training vocabulary is very large, the fine-tuned model is also very large. I hope you can provide a tutorial on how to pre-train mBART on our own data.

What is your question?

  1. Can I pre-train the mBART model if I only have sentence-level monolingual data?
  2. How should the languages for pre-training be chosen? Are there any guiding principles?
  3. How should the vocabulary size be set? Is it related to the number of languages or the size of the training data?
  4. If I do not use the sentencepiece tool but my own preprocessed data, can I still use the framework to pre-train the mBART model?
  5. How long would it take to pre-train the mBART model on 1080Ti or 2080Ti GPUs (I only have 8 GPUs)?
@hischen

hischen commented May 21, 2020


Hey fansiawang, sorry to be a bother.
While fine-tuning the mBART pre-trained CC25 model on the WMT16 En-Ro dataset, I keep running into out-of-memory problems, like this:

RuntimeError: CUDA out of memory. Tried to allocate 978.00 MiB (GPU 0; 10.76 GiB total capacity; 8.73 GiB already allocated; 659.12 MiB free; 9.26 GiB reserved in total by PyTorch)

Would you please give me some help?
I fine-tuned on a single RTX 2080 Ti GPU.
Since the WMT16 En-Ro dataset is about 0.6M sentence pairs, how large is your Japanese-Korean dataset, and what kind of GPU did you use for fine-tuning?
Thanks!

@vikrant97

vikrant97 commented May 21, 2020

@ngoyal2707 Hi, I have the same issue: the released pre-trained model is too large and is not trained on the languages I am interested in, so I want to pre-train mBART on my own data. Can you please provide any pointers on how to do that? Thanks!

@fansiawang
Author

fansiawang commented May 25, 2020

The OOM error you mentioned has nothing to do with the size of the training data. The root cause is that the vocabulary used by the pre-trained model is too large, so the resulting parameters cannot be loaded onto an ordinary GPU such as a 1080 Ti or 2080 Ti.

There are three options to choose from here:

  1. train on the CPU, but the speed will be very slow;
  2. use a GPU with more memory, such as a P100;
  3. trim the pre-trained model. This is our best choice.

Most of the words in the large vocabulary used by the original pre-trained model are not actually used during fine-tuning, so this redundant information can be removed. In an NMT model, the embedding matrix accounts for the largest share of the parameters, so our trimming work mainly focuses on shrinking the embedding matrix of the pre-trained model.
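To get a feel for the scale, here is a rough back-of-the-envelope calculation of just the embedding parameters (the vocabulary sizes are approximate and for illustration only):

embed_dim = 1024           # mBART large embedding dimension
original_vocab = 250_000   # approximate size of the mbart.cc25 vocabulary
trimmed_vocab = 25_000     # e.g., a vocabulary built only from the fine-tuning data

print(f"original embedding: {original_vocab * embed_dim / 1e6:.0f}M parameters")  # ~256M
print(f"trimmed embedding:  {trimmed_vocab * embed_dim / 1e6:.0f}M parameters")   # ~26M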

Next I will introduce how to trim the pre-trained model.

  1. Build a new vocabulary from the fine-tuning data. Generally speaking, this vocabulary is much smaller than that of the pre-trained model.
  2. For each entry in the new vocabulary, find its position in the old vocabulary.
  3. Slice the embedding matrix according to the positions obtained in the previous step, which yields a smaller embedding matrix that only covers the fine-tuning data.
  4. Keep all other parameters unchanged, replace the embedding matrix with the new one obtained in the previous step, and save the model again.

Generally speaking, this new model will be much smaller than the original pre-trained model, small enough to be loaded onto a normal GPU for training.

Note: when fine-tuning on a trimmed pre-trained model, be sure to regenerate the binarized training data with the new vocabulary; see the example below.
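For example, the binarization step could look roughly like this once the trimmed dictionary exists (a sketch only; the language codes, file names and paths are illustrative and follow the mBART README convention, not something from this thread):

fairseq-preprocess \
--source-lang ja_XX --target-lang ko_KR \
--trainpref ./ft/train.spm --validpref ./ft/valid.spm --testpref ./ft/test.spm \
--srcdict ./ft/dict.txt --tgtdict ./ft/dict.txt \
--thresholdsrc 0 --thresholdtgt 0 \
--destdir ./ft/bin --workers 8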

@hischen

hischen commented May 29, 2020

  3. Slice the embedding matrix according to the positions obtained in the previous step, which yields a smaller embedding matrix that only covers the fine-tuning data.

Thanks very much !!! It is so kind of you!

@ddaspit

ddaspit commented Jun 4, 2020

@fansiawang Is there any code that illustrates your proposed method for cutting the pre-trained model? Thank you.

@SunbowLiu

Note: when fine-tuning on a trimmed pre-trained model, be sure to regenerate the binarized training data with the new vocabulary.

Hello Fansia,

Thank you for sharing your experience; it really helps me a lot.
The embedding pruning method you mentioned is intuitively appealing; it would be even better if you could make the code freely available.

Best,
Xuebo Liu

@ddaspit

ddaspit commented Jun 22, 2020

@SunbowLiu Here is a script that I wrote to reduce the size of the pre-trained model by pruning the word embeddings for fine-tuning:

import argparse
import os
from typing import List

import torch

from fairseq.data import Dictionary


def load_dict(langs: List[str], path: str) -> Dictionary:
    d = Dictionary.load(path)
    for l in langs:
        d.add_symbol(f"[{l}]")
    d.add_symbol("<mask>")
    return d


def main() -> None:
    parser = argparse.ArgumentParser(description="Trims pre-trained mBART model for fine-tuning.")
    parser.add_argument("--pre-train-dir", type=str, required=True, help="The pre-trained mBART model directory.")
    parser.add_argument("--ft-dict", type=str, required=True, help="The fine-tuning model dictionary.")
    parser.add_argument("--langs", type=str, required=True, help="The pre-trained model languages.")
    parser.add_argument("--output", type=str, required=True, help="The trimmed mBART model.")
    args = parser.parse_args()

    langs = args.langs.split(",")
    pre_dict = load_dict(langs, os.path.join(args.pre_train_dir, "dict.txt"))
    ft_dict = load_dict(langs, args.ft_dict)
    data = torch.load(os.path.join(args.pre_train_dir, "model.pt"))
    model = data["model"]

    # Map each index in the fine-tuning dictionary to its index in the pre-trained dictionary.
    mapping: List[int] = []
    for i in range(len(ft_dict)):
        word = ft_dict[i]
        mapping.append(pre_dict.index(word))

    # Copy over only the embedding rows that the fine-tuning dictionary actually uses.
    for name in ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]:
        pre_tensor: torch.Tensor = model[name]
        ft_tensor = torch.zeros(
            [len(ft_dict), 1024], dtype=pre_tensor.dtype, layout=pre_tensor.layout, device=pre_tensor.device,
        )
        for ft_i, pre_i in enumerate(mapping):
            ft_tensor[ft_i] = pre_tensor[pre_i]
        model[name] = ft_tensor

    torch.save(data, args.output)


if __name__ == "__main__":
    main()

Here is an example of how you run it (I saved the script as trim_mbart.py):

python trim_mbart.py --pre-train-dir ./mbart.cc25 --ft-dict ./ft/dict.txt --langs ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN --output ./ft/model.pt
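If it helps, a quick sanity check on the trimmed checkpoint confirms that the embedding matrices were actually shrunk (a sketch; the path matches the --output above):

import torch

data = torch.load("./ft/model.pt", map_location="cpu")
for name in ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]:
    # Each should now have len(ft_dict) rows and 1024 columns (the mBART large embedding size).
    print(name, tuple(data["model"][name].shape))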

@SunbowLiu


Hi Damien,

Thank you so much for your wonderful code.

How about the final translation performance w/ or w/o this pruning technique? I have a concern about the numerical stability after pruning, e.g., the denominator of the final softmax layer becomes much smaller. Maybe it is an easy thing that can be dealt with during fine-tuning?

What is your opinion of this concern?

Thank you again for your code!

Best,
Xuebo Liu

@ddaspit

ddaspit commented Jun 23, 2020

Using this script, I have tried to reproduce the EN-RO results from the paper, but I only get a BLEU of ~36. It's close, but not quite as good. It certainly could be because of the pruning. I am training on a very different setup (single GPU), so it might be because I haven't got the batch size and learning rate schedule quite right to reproduce the results from the paper. Also, the tokenization steps are kind of tricky for EN-RO, so I might not have gotten that correct.

@SunbowLiu

@ddaspit Thank you for your reply. Using the pretrained model, I can obtain the same BLEU score (37.8) as the paper. Now I will start fine-tuning the model myself and will share the new results with you.

@sshleifer
Contributor

@ddaspit Do you have code to create ft/dict.txt, e.g. for WMT en-ro?

The downloads for mbart-cc25 and mbart-enro provide identical dict.txt afaict?

@SunbowLiu

@sshleifer, just use the default fairseq preprocessor with --joined-dictionary and the fine-tuning corpus (an example command follows my results below).
My results are as follows:
WMT'16 EN-RO baseline: 37.7
With embedding pruning: 37.3 (fast and lightweight)
I also found that a smaller label smoothing value (0.1/0.05) helps embedding pruning, raising the score to 37.4/37.5, but this was just a quick experiment and I don't know whether this small trick helps other high-resource benchmarks.
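For reference, such a preprocessing call might look roughly like this (a sketch only; language codes and paths are illustrative, and the resulting shared dictionary is written into --destdir):

fairseq-preprocess \
--source-lang en_XX --target-lang ro_RO \
--trainpref ./ft/train.spm --validpref ./ft/valid.spm --testpref ./ft/test.spm \
--joined-dictionary \
--destdir ./ft/bin --workers 8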

@ddaspit

ddaspit commented Jul 21, 2020

@sshleifer Here is the code I used to generate the vocabulary file.

import argparse
from glob import glob

from fairseq.data import Dictionary
from fairseq.tokenizer import tokenize_line

def pad_dict(d: Dictionary, num_extra_symbols: int, padding_factor: int = 8) -> None:
    # Pad with filler symbols so that, once the language codes and <mask> are added
    # later, the dictionary size is a multiple of padding_factor.
    i = 0
    while (len(d) + num_extra_symbols) % padding_factor != 0:
        symbol = f"madeupword{i:04d}"
        d.add_symbol(symbol, n=0)
        i += 1

def main() -> None:
    parser = argparse.ArgumentParser(description="Build vocabulary from corpus data.")
    parser.add_argument("--corpus-data", type=str, required=True, help="The path pattern (glob) to all tokenized corpus files (train, test, val).")
    parser.add_argument("--langs", type=str, required=True, help="The pre-trained model languages.")
    parser.add_argument("--output", type=str, required=True, help="The vocabulary file.")
    args = parser.parse_args()

    langs = args.langs.split(",")
    ft_dict = Dictionary()
    for data_path in glob(args.corpus_data):
        Dictionary.add_file_to_dictionary(data_path, ft_dict, tokenize_line, 4)
    ft_dict.finalize(padding_factor=0)
    pad_dict(ft_dict, len(langs) + 1)
    ft_dict.save(args.output)

if __name__ == "__main__":
    main()

Here is an example of how you run it (I saved the script as build_vocab.py):

python build_vocab.py --corpus-data "./ft/*.spm.*" --langs ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN --output ./ft/dict.txt

@SunbowLiu Those results look good.

@sshleifer
Contributor

sshleifer commented Jul 21, 2020

Very cool. I am trying to trim down the fine-tuned mbart-en-ro model to make it faster/lighter.
trim_mbart.py and build_vocab.py work well, but I think I need to prune the sentencepiece.bpe.model so that it can take English text and produce ids that are in the new dict.txt. Have either of you done that? Then the "trimmed" model would be able to support any English text.

import sentencepiece as spm

sp_model = spm.SentencePieceProcessor()

# same file as original sentence.bpe.model
sp_model.Load('/home/shleifer/fairseq/enro_trimmed/sentence.bpe.model')

sp_model.encode_as_ids("UN Chief Says There Is No Military Solution in Syria")
=> [8273, 127872, 25915, 6, 8621, 2070, 437, 67484, 52, 187894, 22, 51711]

@ddaspit

ddaspit commented Jul 22, 2020

@sshleifer I haven't needed to do anything like that. I just use spm_encode to encode the data before calling fairseq-preprocess. If you need to get the pruned vocabulary ids, you can probably do something like this:

  1. Load dict.txt using the Dictionary class in fairseq.
  2. Use SentencePieceProcessor.EncodeAsPieces to encode the sentence.
  3. Convert the array of pieces to a space delimited string.
  4. Call Dictionary.encode_line on the string to get the ids.
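A minimal sketch of those four steps, assuming the trimmed dict.txt and the original sentence.bpe.model (the paths are illustrative):

import sentencepiece as spm
from fairseq.data import Dictionary

# Load the trimmed fine-tuning dictionary (add the language codes and <mask> here
# if your setup expects them, as in trim_mbart.py).
ft_dict = Dictionary.load("./ft/dict.txt")

# Encode with the original SentencePiece model, then map the pieces to pruned ids.
sp = spm.SentencePieceProcessor()
sp.Load("./mbart.cc25/sentence.bpe.model")
pieces = sp.EncodeAsPieces("UN Chief Says There Is No Military Solution in Syria")
ids = ft_dict.encode_line(" ".join(pieces), add_if_not_exist=False, append_eos=False)
print(ids)  # tensor of ids that are valid for the trimmed vocabulary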

@sshleifer
Contributor

Does anyone have a link to the enro data? I tried to rerun finetuning on a new machine and I got

  File "fairseq/data/data_utils_fast.pyx", line 50, in fairseq.data.data_utils_fast.batch_by_size_fast
    assert max_tokens <= 0 or sample_len <= max_tokens, (
AssertionError: sentence at index 1524 of size 130 exceeds max_tokens limit of 128!

I have 610,319 lines, 14,280,744 words and 98,251,174 chars in train.ro.

@SunbowLiu

http://data.statmt.org/wmt16/translation-task/training-parallel-ep-v8.tgz
http://opus.lingfil.uu.se/SETIMES2.php
I concatenated these two corpora and finally got 612,422 instances.

@kr-sundaram

kr-sundaram commented Aug 5, 2020

Hi Guys,

I am new to machine translation and need your help to understand the process of using the pretrained mBART model for fine-tuning an EN-DE machine translation task.

I just wanted to make sure whether my understanding of the steps to be followed is correct in order to achieve my goal.

  1. Create a corpus for DE (src) -> EN (trg); let's say train.de, train.en, valid.de, valid.en, test.de and test.en.

  2. Tokenize all sentence pairs of the corpus with the SentencePiece model, i.e. sentence.bpe.model. Let's say my tokenized corpus files are train.spm.de, train.spm.en, valid.spm.de, valid.spm.en, test.spm.de and test.spm.en.
    For example, if my corpus previously contained the sentence 'I saw a girl with a telescope.', after tokenization it will look like '▁I ▁saw ▁a ▁girl ▁with ▁a ▁ telescope .' (Please clarify whether this is correct and this is how it should be.)

  3. Now, once all files are updated and kept in a particular directory, let's say ft, the dict.txt file is created by executing the build_vocab.py script, passing --corpus-data "./ft/*.spm.*" so that the glob covers all 6 files created in step 2. I have attached the dict that was generated for my dataset. Please have a look and check whether it is fine.

dict.txt

  4. Once the dict.txt file is generated, we prune the pretrained model using the trim_mbart.py script, so that it generates a lighter version of the pretrained model for the EN-DE language pair. (I performed up to this step, but the trimmed model that was generated is 5GB; I don't know if I am doing something wrong.)

  5. Now, we have to preprocess the dataset using the fairseq-preprocess command as below:

DICT=dict.txt
fairseq-preprocess \
--source-lang de \
--target-lang en \
--trainpref ./ft/train.spm \
--validpref ./ft/valid.spm \
--testpref ./ft/test.spm \
--destdir ./output \
--thresholdtgt 0 \
--thresholdsrc 0 \
--srcdict ${DICT} \
--tgtdict ${DICT} \
--workers 70 --> (could you please let me know what is the use of this parameter, and what it should be if I have access to only 1 GPU?)

  6. Train the model using the command below:

PRETRAIN=path_of_trimmed_bart_pretrained_model
langs=ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN

fairseq-train ./ft \
--encoder-normalize-before --decoder-normalize-before \
--arch mbart_large --layernorm-embedding \
--task translation_from_pretrained_bart \
--source-lang de_DE --target-lang en_XX \
--criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
--optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
--lr-scheduler polynomial_decay --lr 3e-05 --min-lr -1 --warmup-updates 2500 --total-num-update 40000 \
--dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
--max-tokens 1024 --update-freq 2 \
--save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \
--seed 222 --log-format simple --log-interval 2 \
--restore-file $PRETRAIN \
--reset-optimizer --reset-meters --reset-dataloader --reset-lr-scheduler \
--langs $langs \
--ddp-backend no_c10d --> (please explain the purpose of this parameter as well)

  7. Generate the translation using the command below:

model_dir=MBART_finetuned_deen

fairseq-generate path_2_data \
--path $model_dir/model.pt \
--task translation_from_pretrained_bart \
--gen-subset test \
-t en_XX -s de_DE \
--bpe 'sentencepiece' \
--sentencepiece-model $model_dir/sentence.bpe.model \ --> (this sentencepiece model will be the one that comes with the pretrained model, right?)
--sacrebleu --remove-bpe 'sentencepiece' \
--max-sentences 32 \
--langs $langs > de_en

Thanks in advance for helping me out!
Kindly correct me if I am wrong somewhere.
Thanks for your patience! :)

@ddaspit

ddaspit commented Aug 6, 2020

@kr-sundaram Everything looks correct. You definitely want to use the sentencepiece model that was used for the pretrained model when tokenizing and detokenizing your data. Your dictionary file looks good. Regarding the trimmed model size, 5GB sounds correct. That was around the size I got with a similar vocabulary size. For the fairseq-preprocess call, --workers 70 is fine. It just specifies the number of worker processes that are spawned to perform the preprocessing. Make sure that you use the path to the output from preprocessing in the fairseq-train call. The --ddp-backend no_c10d parameter tells fairseq to use the old distributed data parallel implementation. I'm not sure why it is being used for mBART training. I think that the newer (c10d) implementation can have issues with certain architectures, which may be the case for mBART.

@kr-sundaram

kr-sundaram commented Aug 6, 2020

Thanks @ddaspit for your reply!!

Hey, I have some more doubts.

  1. Since 5GB is a huge size, could you please explain the steps, if you are aware of them, for pruning the sentencepiece.bpe.model that ships with the pre-trained model, so that the model size can be reduced further?

  2. In step 2 above, when I tokenized the sentences using sentence.bpe.model, the sentences in my corpus look like this:
    ▁like , ▁" w hy ▁do ▁not ▁you ▁send ▁her ▁that ▁i ▁never ▁asked ?
    I just want to bring your attention to the whitespace between subwords. I did not delete this extra whitespace while preprocessing the corpus as stated in step 5. Is this the correct way of doing it, or should I delete the whitespace and then run the fairseq preprocessing?

  3. I used the commands below to generate the translated text, as I was not sure which one is correct, and I received errors for all of them. All parameters of these commands are the same except --langs. Commands a and b produce the same error, whereas command c gives some additional information on top of it; see below for details:

a.

!fairseq-generate '/content/drive/My Drive/mBART/output/' \
--path '/content/drive/My Drive/mBART/model.pt' \
--task translation_from_pretrained_bart \
--gen-subset test \
-t en -s de \
--bpe 'sentencepiece' \
--sentencepiece-model '/content/drive/My Drive/mBART/mbart.cc25/sentence.bpe.model' \
--sacrebleu --remove-bpe 'sentencepiece' \
--max-sentences 32 \
--langs ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN > de_en

Error received

Traceback (most recent call last):
  File "/usr/local/bin/fairseq-generate", line 8, in <module>
    sys.exit(cli_main())
  File "/usr/local/lib/python3.6/dist-packages/fairseq_cli/generate.py", line 281, in cli_main
    main(args)
  File "/usr/local/lib/python3.6/dist-packages/fairseq_cli/generate.py", line 38, in main
    return _main(args, sys.stdout)
  File "/usr/local/lib/python3.6/dist-packages/fairseq_cli/generate.py", line 87, in _main
    suffix=getattr(args, "checkpoint_suffix", ""),
  File "/usr/local/lib/python3.6/dist-packages/fairseq/checkpoint_utils.py", line 192, in load_model_ensemble
    filenames, arg_overrides, task, strict, suffix,
  File "/usr/local/lib/python3.6/dist-packages/fairseq/checkpoint_utils.py", line 213, in load_model_ensemble_and_task
    model.load_state_dict(state["model"], strict=strict, args=args)
  File "/usr/local/lib/python3.6/dist-packages/fairseq/models/fairseq_model.py", line 93, in load_state_dict
    return super().load_state_dict(new_state_dict, strict)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 1045, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for BARTModel:
Unexpected key(s) in state_dict: "encoder.layernorm_embedding.weight", "encoder.layernorm_embedding.bias", "decoder.layernorm_embedding.weight", "decoder.layernorm_embedding.bias".

b.
!fairseq-generate '/content/drive/My Drive/mBART/output/' --path '/content/drive/My Drive/mBART/model.pt' --task translation_from_pretrained_bart --gen-subset test -t en -s de --bpe 'sentencepiece' --sentencepiece-model '/content/drive/My Drive/mBART/mbart.cc25/sentence.bpe.model' --sacrebleu --remove-bpe 'sentencepiece' --max-sentences 32 --langs ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN

The error received is the same as for command a.

c.
!fairseq-generate '/content/drive/My Drive/mBART/output/' --path '/content/drive/My Drive/mBART/model.pt' --task translation_from_pretrained_bart --gen-subset test -t en -s de --bpe 'sentencepiece' --sentencepiece-model '/content/drive/My Drive/mBART/mbart.cc25/sentence.bpe.model' --sacrebleu --remove-bpe 'sentencepiece' --max-sentences 32 --langs de_en

The error is the same as stated for command a, but this time with some additional information beyond what I got from commands a and b:

RuntimeError: Error(s) in loading state_dict for BARTModel:
Unexpected key(s) in state_dict: "encoder.layernorm_embedding.weight", "encoder.layernorm_embedding.bias", "decoder.layernorm_embedding.weight", "decoder.layernorm_embedding.bias".
size mismatch for encoder.embed_tokens.weight: copying a param with shape torch.Size([23536, 1024]) from checkpoint, the shape in current model is torch.Size([23512, 1024]).
size mismatch for decoder.embed_tokens.weight: copying a param with shape torch.Size([23536, 1024]) from checkpoint, the shape in current model is torch.Size([23512, 1024]).
size mismatch for decoder.output_projection.weight: copying a param with shape torch.Size([23536, 1024]) from checkpoint, the shape in current model is torch.Size([23512, 1024]).

Can you help me understand the issue and how to rectify it, and which of the above three commands should be used to generate translations of the test set? This error seems confusing: I had only run training to check whether the code works, and once I found it working fine, I interrupted the training in the middle. Do you think the error might be caused by stopping the run partway through? That doesn't make sense to me, as my understanding is that the checkpoint size and the current model size should stay the same throughout training. Kindly clarify if I am wrong!

  4. For how many hours did your model run, and on how many training and validation sentences? I need to know because I will only have access to one GPU, so I want to estimate my training time as well.

  5. How did you calculate your BLEU score? Do you have any scripts or examples I could use to get an idea?

Thank you!

@ddaspit

ddaspit commented Aug 7, 2020

@kr-sundaram

  1. Reducing the size of the sentencepiece model won't affect the size of the pre-trained model. The size of the vocabulary (dict.txt) is what affects the size of the pre-trained model. The original vocabulary has ~250000 tokens. Your trimmed vocabulary has only ~23000 tokens. This allows us to remove all of the unused tokens from the word embedding layers in the pre-trained model, which will greatly reduce the amount of GPU memory that is required when training.
  2. You shouldn't delete the whitespace. The spaces are there to delimit the tokens.
  3. You need the --langs parameter when inferencing. That is what is causing the "size mismatch" errors in c. I am guessing that the --path parameter is specified incorrectly. It should be pointing at a checkpoint from your training run and not at the pre-trained model. I specify the directory to save checkpoints while training using the --save-dir parameter.
  4. I didn't save information on how long it took to train. On a single GPU, you should expect several hours of training.
  5. I use sacrebleu to calculate the BLEU score. fairseq-generate will compute the BLEU score using sacrebleu if you specify the --sacrebleu parameter. Here is the fairseq-generate call that I use:
fairseq-generate $CORPUS/preprocessed \
--path $CORPUS/checkpoints/checkpoint_best.pt \
--task translation_from_pretrained_bart \
--gen-subset test \
-s $SRC \
-t $TRG \
--remove-bpe sentencepiece \
--sacrebleu \
--max-sentences 32 \
--langs $LANGS > $CORPUS/results

@kr-sundaram

kr-sundaram commented Aug 7, 2020

Thanks @ddaspit for the reply!!

As you said in point 1, my vocabulary has 23,506 tokens. I then used the trim_mbart.py script to trim the pre-trained model to my dictionary, but the size of my new pre-trained model still did not decrease much: the original pre-trained model is around 5.7 GB and my new one is around 5.3 GB.

Is there any way to reduce the model size further than what I achieved, or is the current size the minimum I can get?

Finally, with your kind help and guidance, I was able to fine-tune on a small DE -> EN dataset as a test run, and fairseq-generate is also working. :)

I wanted to run fairseq-interactive, which is still not working for me. Below is the command I used to run it:

!fairseq-interactive '/content/drive/My Drive/mBART/output' \
--path '/content/drive/My Drive/mBART/model_ckpt/checkpoint_best.pt' \
--beam 5 --source-lang de_DE --target-lang en_XX \
--bpe sentencepiece --sentencepiece-model '/content/drive/My Drive/mBART/mbart.cc25/sentence.bpe.model'

Could you please share your command that you executed for fairseq-interactive?

I only ran for 1 epoch to check whether fairseq-generate works for me this time, which is why the translation is very poor at the moment.
The sentences in the results file are in ascending order of the number of tokens per sentence, which is why the 5812th sentence of my test file appears as the first sentence. I understand that S is the source and T is the target, but I am not sure about the other three prefixes, i.e. H, D and P. Could you please explain them, and the significance of the numbers associated with H, D and P? And which one actually reports the BLEU score?

S-5812	wir alle sind europäer.[de_DE]
T-5812	we are all europeans.
H-5812	-2.0999746322631836	it is a.
D-5812	-2.0999746322631836	it is a.
P-5812	-4.6554 -0.6944 -3.1002 -3.7654 -0.2138 -0.1706

@ddaspit

ddaspit commented Aug 10, 2020

@kr-sundaram
Although the size of the file does not decrease much, the memory usage during training is significantly smaller. I was able to fine-tune an mBART model on a GPU with 8GB of memory after I trimmed the model. Unaltered, you would probably need a GPU with 16GB of memory (maybe 11GB if you reduce the batch size and use mixed precision training).

I haven't used fairseq-interactive with mBART. I would assume that you would run it with similar parameters to a fairseq-generate call.

This is how you would interpret the results from fairseq-generate:

  • S: source sentence
  • T: target sentence (reference)
  • H: best hypothesis with score from beam search (after BPE has been removed)
  • D: detokenized hypothesis with score from beam search (no tokenizers are used so it is the same as H)
  • P: scores for each token in the hypothesis

The BLEU score should be at the end of the output. Look for a line like this

Generate test with beam=5: BLEU = ...
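If you want to re-score the output outside of fairseq-generate, one common approach (a sketch, assuming the T/D line format shown above and the sacrebleu Python package) is to pair up the T and D lines and score them with sacrebleu:

import sacrebleu

refs, hyps = [], []
with open("results") as f:  # the file the fairseq-generate output was redirected to
    for line in f:
        if line.startswith("T-"):
            refs.append(line.split("\t", 1)[1].strip())   # T-<id><TAB>reference
        elif line.startswith("D-"):
            hyps.append(line.split("\t", 2)[2].strip())   # D-<id><TAB>score<TAB>hypothesis
# T and D lines appear in the same order, so the two lists stay aligned.
print(sacrebleu.corpus_bleu(hyps, [refs]).score)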

@XiaoqingNLP

@fansiawang Hi, have you found an answer to your questions? I have the same questions as you:

Can I pre-train the mBART model if I only have sentence-level monolingual data?
How should the languages for pre-training be chosen? Are there any guiding principles?
How should the vocabulary size be set? Is it related to the number of languages or the size of the training data?
If I do not use the sentencepiece tool but my own preprocessed data, can I still use the framework to pre-train the mBART model?
How long would it take to pre-train the mBART model on 1080Ti or 2080Ti GPUs (I only have 8 GPUs)?

@learnercat

Thanks to @ddaspit for the awesome scripts to trim mBART for low-resource GPUs. I ran the script to trim the 'mbart-large-cc25' model:
python trim_mbart.py --pre-train-dir ./mbart.cc25 --ft-dict ./ft/dict.txt --langs ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN --output ./ft/model.pt
I got a "KeyError: 'model'" error, as shown in the attached screenshot of the traceback.
Please help.

@ddaspit

ddaspit commented Apr 29, 2021

@learnercat The script was designed to work with the pre-trained model files linked from the mBART README in the fairseq repo. I haven't tried it with the model files from huggingface. I am guessing that the huggingface model files have a different internal structure.

@learnercat

@ddaspit Thank you very much. Now it works with the model files from the mBART README.

@akshayg08

fairseq-generate ~/CM/mt_enghinglish/transliterated_normalized/destdir/ --path checkpoint_best.pt -s en_XX -t hi_IN --beam 5 --batch-size 32 --task translation_from_pretrained_bart --gen-subset test --bpe "sentencepiece" --sentencepiece-model ../mbart.cc25.v2/sentence.bpe.model --remove-bpe "sentencepiece" --langs ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN

I am using this command to fine-tune the EN-HI model, but I am getting a size mismatch error:
RuntimeError: Error(s) in loading state_dict for BARTModel:
size mismatch for encoder.embed_tokens.weight: copying a param with shape torch.Size([47336, 1024]) from checkpoint, the shape in current model is torch.Size([47360, 1024]).
size mismatch for decoder.embed_tokens.weight: copying a param with shape torch.Size([47336, 1024]) from checkpoint, the shape in current model is torch.Size([47360, 1024]).
size mismatch for decoder.output_projection.weight: copying a param with shape torch.Size([47336, 1024]) from checkpoint, the shape in current model is torch.Size([47360, 1024]).

The size of dict.txt created by the build_vocab script is 47,330. And in the pruned model, the size of the encoder embedding matrix is torch.Size([47360, 1024]). If the vocab size is 47,330, why does the embedding matrix have 47,360 rows?
Could you please help?

@zide05

zide05 commented Jun 2, 2021

Is there any code released for pretraining mBART?

@firdota

firdota commented Jul 22, 2021

Is there any code released for pretraining mBART?

I have the same question.

@firdota

firdota commented Jul 29, 2021

Hello, have you made any progress on pretraining mBART? @fansiawang

@IamAdiSri

IamAdiSri commented Mar 17, 2022


I'm running into this same issue. Were you able to figure out what the problem was? @akshayg08

@jaspock

jaspock commented Mar 17, 2022

@IamAdiSri, I guess you are having a problem similar to the embedding size mismatch issue discussed here (see the whole thread, anyway) for the similar mBART50 and MM100 models. In your case, mBART supports 25 languages, and if you add 47330+25 and round up to the closest multiple of 8 (for efficiency reasons, dictionary sizes are multiples of 8) you get 47360. You will probably be able to move on if you manually add mBART's language codes to the dictionary, followed by 5 extra symbols such as "madeupword01", "madeupword02", etc. According to mBART's code, these languages are: ar_AR, cs_CZ, de_DE, en_XX, es_XX, et_EE, fi_FI, fr_XX, gu_IN, hi_IN, it_IT, ja_XX, kk_KZ, ko_KR, lt_LT, lv_LV, my_MM, ne_NP, nl_XX, ro_RO, ru_RU, si_LK, tr_TR, vi_VN, zh_CN.
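A quick back-of-the-envelope check of that arithmetic (a sketch only; the exact bookkeeping of the special and filler symbols is handled by fairseq's Dictionary):

import math

vocab_entries = 47330   # entries in dict.txt reported above
extra_symbols = 25      # mBART language codes (a <mask> symbol may also be added)
padding_factor = 8      # fairseq pads dictionary sizes to a multiple of 8

padded = math.ceil((vocab_entries + extra_symbols) / padding_factor) * padding_factor
print(padded)  # 47360, matching the embedding rows expected in the error message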

@chirico85

@ddaspit Thanks a lot for the scripts and advice, and to the rest of you for the great discussion, which helped me out with some questions.
I managed to prune the mBART pretrained model and fine-tune it on German paraphrases for the translation task complex text -> simple text (work by @louismartin). However, inference produces no actual text, just a single dot or comma. Did any of you experience similar problems with this or the usual translation tasks after pruning?

@tianshuailu

@ddaspit Hi ddaspit, I'm following the instructions you provided above and they are super helpful, so thank you so much. But I have a few questions.

  1. After the model had trained for many hours, it stopped with the following error. Does this mean that the model has completed training, or is it an error that I need to fix?

Traceback (most recent call last):
  File "/home/user/ltian/anaconda3/bin/fairseq-train", line 33, in <module>
    sys.exit(load_entry_point('fairseq', 'console_scripts', 'fairseq-train')())
  File "/home/user/ltian/anaconda3/fairseq/fairseq_cli/train.py", line 352, in cli_main
    distributed_utils.call_main(args, main)
  File "/home/user/ltian/anaconda3/fairseq/fairseq/distributed_utils.py", line 301, in call_main
    main(args, **kwargs)
  File "/home/user/ltian/anaconda3/fairseq/fairseq_cli/train.py", line 125, in main
    valid_losses, should_stop = train(args, trainer, task, epoch_itr)
  File "/home/user/ltian/anaconda3/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/user/ltian/anaconda3/fairseq/fairseq_cli/train.py", line 208, in train
    log_output = trainer.train_step(samples)
  File "/home/user/ltian/anaconda3/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/user/ltian/anaconda3/fairseq/fairseq/trainer.py", line 592, in train_step
    raise FloatingPointError("gradients are Nan/Inf")
FloatingPointError: gradients are Nan/Inf

  2. Also, the command doesn't specify a stopping criterion for fairseq-train such as --max-epoch, --max-update, --stop-time-hours or --stop-min-lr, right? So how do we know when the model will complete training?

fairseq-train /srv/scratch3/ltian/output \
--encoder-normalize-before --decoder-normalize-before \
--arch mbart_large --layernorm-embedding \
--task translation_from_pretrained_bart \
--source-lang hi_IN --target-lang en_XX \
--criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
--optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
--lr-scheduler polynomial_decay --lr 3e-05 --min-lr -1 --warmup-updates 2500 --total-num-update 40000 \
--dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
--max-tokens 1024 --update-freq 2 \
--save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \
--seed 222 --log-format simple --log-interval 2 \
--save-dir /srv/scratch3/ltian \
--restore-file /srv/scratch3/ltian/model.pt \
--reset-optimizer --reset-meters --reset-dataloader --reset-lr-scheduler \
--langs ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN \
--ddp-backend no_c10d

  3. And this is the list of checkpoints it generated. It seems the model only trained a little into epoch 2 and then just stopped. Do you think that is normal, or is it because I used too much data (the parallel data contains 5.6 million sentence pairs)? And may I ask which one is the final result of the training: is it checkpoint_best.pt or checkpoint_last.pt, or did it update model.pt?

checkpoint_1_100000.pt checkpoint_1_70000.pt checkpoint_1_85000.pt checkpoint_2_105000.pt model.pt
checkpoint_1_60000.pt checkpoint_1_75000.pt checkpoint_1_90000.pt checkpoint_best.pt
checkpoint_1_65000.pt checkpoint_1_80000.pt checkpoint_1_95000.pt checkpoint_last.pt
