Issues with pretraining mBART models #2120
Comments
Hey @fansiawang, sorry to be a bother. Would you please give me some help? |
@ngoyal2707 Hi, I have the same issue: the released pretrained model is too large and is not trained on the languages I am interested in, so I want to pretrain mBART on my own data. Can you please provide any pointers on how to do that? Thanks! |
The OOM error you mentioned has nothing to do with the size of the training data. The root cause is that the vocabulary used by the pre-trained model is too large, so there are too many parameters to load on an ordinary GPU such as a 1080Ti or 2080Ti. There are three options to choose from here:
Most of the words in the large vocabulary used by the original pre-trained model are not actually used during fine-tuning, so this redundant information can be removed. In an NMT model, the embedding matrix actually accounts for the largest share of the parameters, so our cutting work mainly focuses on shrinking the embedding matrix of the pre-trained model. Next I will introduce how to cut the pre-trained model.
Generally speaking, this new model will be much smaller than the original pre-trained model, small enough to be loaded onto a normal GPU for training. Note: when fine-tuning on a trimmed pre-trained model, be sure to generate the training data in binary format with the new vocabulary. |
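In essence, the trimming step is a row-gather on the embedding matrices. A minimal, self-contained sketch of the idea (toy shapes and indices of my own; the full script appears later in this thread):

import torch

old_embed = torch.randn(250027, 1024)       # pre-trained embeddings (vocab x dim)
keep = torch.tensor([0, 1, 2, 3, 57, 923])  # rows for the words the new vocab keeps
new_embed = old_embed[keep]                 # only the rows we need
print(new_embed.shape)                      # torch.Size([6, 1024])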
Thanks very much!!! It is so kind of you! |
@fansiawang Is there any code that illustrates your proposed method for cutting the pre-trained model? Thank you. |
Hello Fansia, thank you for sharing your experience; it really helps me a lot. Best, |
@SunbowLiu Here is a script that I wrote to reduce the size of the pre-trained model by pruning the word embeddings for fine-tuning:

import argparse
import os
from typing import List

import torch
from fairseq.data import Dictionary


def load_dict(langs: List[str], path: str) -> Dictionary:
    d = Dictionary.load(path)
    for l in langs:
        d.add_symbol(f"[{l}]")
    d.add_symbol("<mask>")
    return d


def main() -> None:
    parser = argparse.ArgumentParser(description="Trims pre-trained mBART model for fine-tuning.")
    parser.add_argument("--pre-train-dir", type=str, required=True, help="The pre-trained mBART model directory.")
    parser.add_argument("--ft-dict", type=str, required=True, help="The fine-tuning model dictionary.")
    parser.add_argument("--langs", type=str, required=True, help="The pre-trained model languages.")
    parser.add_argument("--output", type=str, required=True, help="The trimmed mBART model.")
    args = parser.parse_args()

    langs = args.langs.split(",")
    pre_dict = load_dict(langs, os.path.join(args.pre_train_dir, "dict.txt"))
    ft_dict = load_dict(langs, args.ft_dict)
    data = torch.load(os.path.join(args.pre_train_dir, "model.pt"))
    model = data["model"]

    # For each word in the fine-tuning vocabulary, record its row index in the
    # pre-trained vocabulary.
    mapping: List[int] = []
    for i in range(len(ft_dict)):
        word = ft_dict[i]
        mapping.append(pre_dict.index(word))

    # Gather only the needed rows from the encoder/decoder embedding matrices.
    for name in ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]:
        pre_tensor: torch.Tensor = model[name]
        ft_tensor = torch.zeros(
            [len(ft_dict), 1024], dtype=pre_tensor.dtype, layout=pre_tensor.layout, device=pre_tensor.device,
        )
        for ft_i, pre_i in enumerate(mapping):
            ft_tensor[ft_i] = pre_tensor[pre_i]
        model[name] = ft_tensor

    torch.save(data, args.output)


if __name__ == "__main__":
    main()

Here is an example of how you run it (I saved the script as
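The example command was cut off above; assuming the script was saved as trim_mbart.py (a placeholder name on my part, with placeholder paths), an invocation along these lines should work:

python trim_mbart.py --pre-train-dir ./mbart.cc25.v2 --ft-dict ./ft/dict.txt --langs ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN --output ./ft/model.pt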
|
Hi Damien, thank you so much for your wonderful code. How is the final translation performance with and without this pruning technique? I have a concern about numerical stability after pruning, e.g., the denominator of the final softmax layer becomes much smaller. Maybe it is an easy thing that can be dealt with during fine-tuning? What is your opinion on this concern? Thank you again for your code! Best, |
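On the softmax concern: pruning the vocabulary before the softmax simply renormalizes the surviving probabilities over the smaller set; it does not by itself cause numerical instability. A quick self-contained check (my own illustration, not from the thread):

import torch

torch.manual_seed(0)
logits = torch.randn(250000)            # logits over a large "full" vocabulary
keep = torch.randperm(250000)[:47360]   # indices that survive trimming

full = torch.softmax(logits, dim=0)
pruned = torch.softmax(logits[keep], dim=0)

# softmax over the kept logits == full probabilities renormalized over the kept set
print(torch.allclose(pruned, full[keep] / full[keep].sum()))  # True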
Using this script, I have tried to reproduce the EN-RO results from the paper, but I only get a BLEU of ~36. It's close, but not quite as good. It certainly could be because of the pruning. I am training on a very different setup (single GPU), so it might be because I haven't got the batch size and learning rate schedule quite right to reproduce the results from the paper. Also, the tokenization steps are kind of tricky for EN-RO, so I might not have gotten that correct. |
@ddaspit Thank you for your reply. Using the pretrained model, I can obtain the same BLEU score (37.8) as the paper. Now I will start fine-tuning the model myself, and I will share the new results with you. |
@ddaspit Do you have code to create ft/dict.txt, e.g. for WMT en-ro? The downloads for mbart-cc25 and mbart-enro provide identical dict.txt files. |
@sshleifer, just use the default fairseq preprocessor with --joined-dictionary and the fine-tuning corpus. |
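To spell that out, a preprocessing call along these lines should produce a joint dictionary (the paths and language pair are placeholders of mine, not from the original comment):

fairseq-preprocess \
  --source-lang en_XX --target-lang ro_RO \
  --trainpref ft/train.spm --validpref ft/valid.spm --testpref ft/test.spm \
  --joined-dictionary \
  --destdir ft/data-bin --workers 8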
@sshleifer Here is the code I used to generate the vocabulary file:

import argparse
from glob import glob

from fairseq.data import Dictionary
from fairseq.tokenizer import tokenize_line


def pad_dict(d: Dictionary, num_extra_symbols: int, padding_factor: int = 8) -> None:
    # Add filler symbols so that the dictionary size, once the extra symbols
    # (language codes and <mask>) are added at load time, is a multiple of
    # padding_factor.
    i = 0
    while (len(d) + num_extra_symbols) % padding_factor != 0:
        symbol = f"madeupword{i:04d}"
        d.add_symbol(symbol, n=0)
        i += 1


def main() -> None:
    parser = argparse.ArgumentParser(description="Build vocabulary from corpus data.")
    parser.add_argument("--corpus-data", type=str, required=True, help="The path pattern (glob) to all tokenized corpus files (train, test, val).")
    parser.add_argument("--langs", type=str, required=True, help="The pre-trained model languages.")
    parser.add_argument("--output", type=str, required=True, help="The vocabulary file.")
    args = parser.parse_args()

    langs = args.langs.split(",")
    ft_dict = Dictionary()
    for data_path in glob(args.corpus_data):
        Dictionary.add_file_to_dictionary(data_path, ft_dict, tokenize_line, 4)
    ft_dict.finalize(padding_factor=0)
    pad_dict(ft_dict, len(langs) + 1)
    ft_dict.save(args.output)


if __name__ == "__main__":
    main()

Here is an example of how you run it (I saved the script as build_vocab.py):

python build_vocab.py --corpus-data "./ft/*.spm.*" --langs ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN --output ./ft/dict.txt

@SunbowLiu Those results look good. |
Very cool. I am trying to trim down the finetuned mbart-en-ro to make it faster/lighter.

import sentencepiece as spm

sp_model = spm.SentencePieceProcessor()
# same file as original sentence.bpe.model
sp_model.Load('/home/shleifer/fairseq/enro_trimmed/sentence.bpe.model')
sp_model.encode_as_ids("UN Chief Says There Is No Military Solution in Syria")
# => [8273, 127872, 25915, 6, 8621, 2070, 437, 67484, 52, 187894, 22, 51711]
|
@sshleifer I haven't needed to do anything like that. I just use
|
Does anyone have a link to the enro data? I tried to rerun fine-tuning on a new machine and I got
I have 610,319 lines, 14,280,744 words, 98,251,174 chars in |
http://data.statmt.org/wmt16/translation-task/training-parallel-ep-v8.tgz |
Hi guys, I am new to machine translation and need your help to understand the process of using a pretrained mBART model for fine-tuning on an EN-DE machine translation task. I just wanted to make sure whether my understanding of the steps to be followed is correct, in order to achieve my goal.
DICT=dict.txt
PRETRAIN=path_of_trimmed_bart_pretrained_model
fairseq-train ./ft
model_dir=MBART_finetuned_deen
fairseq-generate path_2_data
Thanks in advance for helping me out! |
@kr-sundaram Everything looks correct. You definitely want to use the sentencepiece model that was used for the pretrained model when tokenizing and detokenizing your data. Your dictionary file looks good. Regarding the trimmed model size, 5GB sounds correct. That was around the size I got with a similar vocabulary size. For the |
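For anyone following along, a fine-tuning invocation adapted from the fairseq mBART README looks roughly like the following; the language pair, paths, and hyperparameters are placeholders to adjust, not the exact command used in this thread:

PRETRAIN=path_of_trimmed_bart_pretrained_model
langs=ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN

fairseq-train ./ft \
  --arch mbart_large --task translation_from_pretrained_bart \
  --source-lang de_DE --target-lang en_XX --langs $langs \
  --encoder-normalize-before --decoder-normalize-before --layernorm-embedding \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
  --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
  --lr-scheduler polynomial_decay --lr 3e-05 --warmup-updates 2500 --total-num-update 40000 \
  --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
  --max-tokens 1024 --update-freq 2 \
  --restore-file $PRETRAIN \
  --reset-optimizer --reset-meters --reset-dataloader --reset-lr-scheduler \
  --save-dir MBART_finetuned_deen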
Thanks @ddaspit for your reply!! Hey, I have got some more doubts.
a.
Error received
b. The error received is the same as what I got after executing command a.
c. The error is the same as stated for command a, but this time it has some additional information beyond what I got after executing commands a and b: RuntimeError: Error(s) in loading state_dict for BARTModel:
Can you help me understand the issue and rectify it, and which one among the above three should be used to generate translations of the test set? This error seems confusing, as I had run it only to check whether the code works, and when I found it was working fine, I interrupted the model training in the middle. Do you think it might be because of stopping the execution in the middle? That doesn't make sense to me, as my understanding is that the sizes of the checkpoints and the current model should stay the same throughout training. Kindly clarify if I am wrong!
Thank you! |
fairseq-generate $CORPUS/preprocessed \
--path $CORPUS/checkpoints/checkpoint_best.pt \
--task translation_from_pretrained_bart \
--gen-subset test \
-s $SRC \
-t $TRG \
--remove-bpe sentencepiece \
--sacrebleu \
--max-sentences 32 \
--langs $LANGS > $CORPUS/results |
Thanks @ddaspit for the reply!! As you said in point 1, in my vocabulary I have 23506 tokens, and then I used
Is there any way to reduce the model size further than what I followed, or is the size I have now the minimum I could ever get? Finally, with your kind help and guidance, I was able to fine-tune with a small DE -> EN dataset as a test run, and I wanted to run
Could you please share the command that you executed for
I just ran for 1 epoch to check if
|
@kr-sundaram I haven't used that myself. This is how you would interpret the results from fairseq-generate:
The BLEU score should be at the end of the output. Look for a line like this:
|
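For illustration, the tail of a fairseq-generate log typically looks like this (the sentences and numbers below are made up by me; the exact BLEU line also varies with the fairseq version and whether --sacrebleu is set):

S-0     <source sentence>
T-0     <reference translation>
H-0     -0.4917 <hypothesis tokens>
D-0     -0.4917 <detokenized hypothesis>
Generate test with beam=5: BLEU4 = 26.64, 58.2/32.1/19.8/12.6 (BP=1.000, ratio=1.012, syslen=2034, reflen=2010)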
@fansiawang Hi, have you solved your problem? I have the same question as yours:
|
Thanks to @ddaspit for the awesome scripts to trim mBART for low-resource GPUs. I ran the script to trim the model 'mbart-large-cc25'. |
@learnercat The script was designed to work with the pre-trained model files linked from the mBART README in the fairseq repo. I haven't tried it with the model files from huggingface. I am guessing that the model files from huggingface have a different internal structure. |
fairseq-generate ~/CM/mt_enghinglish/transliterated_normalized/destdir/ --path checkpoint_best.pt -s en_XX -t hi_IN --beam 5 --batch-size 32 --task translation_from_pretrained_bart --gen-subset test --bpe "sentencepiece" --sentencepiece-model ../mbart.cc25.v2/sentence.bpe.model --remove-bpe "sentencepiece" --langs ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN
I am using this command to fine-tune an EN-HI model, but I am getting a size mismatch error. The size of the dict.txt created by the build_vocab script is 47330, and in the pruned model the size of the embedding matrix from the encoder is torch.Size([47360, 1024]). If the vocab size is 47330, then why does the embedding matrix have 47360 rows? |
Is there any code released for pretraining mBART? |
I have the same question. |
Hello, have you made any progress on pretraining mBART? @fansiawang |
I'm running into this same issue. Were you able to figure out what the problem was? @akshayg08 |
@IamAdiSri, I guess you are having a problem similar to the embedding size mismatch issue discussed here (see the whole thread, anyway) for the related mBART50 and MM100 models. In your case, mBART supports 25 languages, and if you add 47330+25 and round up to the closest multiple of 8 (for efficiency reasons, dictionary sizes are multiples of 8) you get 47360. You will probably be able to move on if you manually add mBART's language codes to the dictionary, followed by 5 extra symbols such as "madeupword01", "madeupword02", etc. According to mBART's code, these languages are: ar_AR, cs_CZ, de_DE, en_XX, es_XX, et_EE, fi_FI, fr_XX, gu_IN, hi_IN, it_IT, ja_XX, kk_KZ, ko_KR, lt_LT, lv_LV, my_MM, ne_NP, nl_XX, ro_RO, ru_RU, si_LK, tr_TR, vi_VN, zh_CN. |
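To make that arithmetic concrete, a sketch of my own mirroring what pad_dict in the build_vocab script above does (whether <mask> is also counted depends on the task setup, so treat this as illustrative):

vocab = 47330                     # entries in the generated dict.txt
lang_codes = 25                   # mBART's language-code symbols added at load time
total = vocab + lang_codes        # 47355
padded = ((total + 7) // 8) * 8   # round up to a multiple of 8
print(padded, padded - total)     # 47360 embedding rows, 5 "madeupword" fillers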
@ddaspit Thanks a lot for the scripts and advice and to the rest for the great discussion, which helped me out with some questions. |
@ddaspit Hi ddaspit, I'm following the instructions you provided above and they're super helpful, so thank you so much. But I have a few questions.
fairseq-train /srv/scratch3/ltian/output
|
❓ Questions and Help
What is your question?
Thanks for releasing the mbart models! Referring to #1758, I reproduced the same results, which are basically close to the results in the paper. I then used my own data (Japanese-Korean) to fine-tune the pre-trained mbart.cc25 model. Compared with the bilingual baseline (6.01 BLEU-SBP), fine-tuning improved it by 2.64 points (finetune: 9.65 BLEU-SBP).
Although I got good results, the absolute quality is still relatively low. Because the pre-training vocabulary is very large, the model after fine-tuning is also very large. I hope you can provide a tutorial on how to pretrain mBART on our own data.
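On the thread's underlying question: there is no official mBART pretraining tutorial in the fairseq README, but fairseq does ship a multilingual_denoising task, so a pretraining run can be sketched roughly as follows. The hyperparameters below echo the mBART paper's description (mask 35% of words in spans drawn from Poisson(3.5), permute sentences); everything here is an unofficial, unverified sketch with placeholder paths that will need tuning on your own data:

langs=ja_XX,ko_KR   # placeholder: your pretraining languages

fairseq-train ./pretrain-data-bin \
  --task multilingual_denoising --langs $langs --add-lang-token \
  --arch mbart_large \
  --sample-break-mode complete_doc --tokens-per-sample 512 \
  --mask 0.35 --mask-length span-poisson --poisson-lambda 3.5 --replace-length 1 \
  --permute-sentences 1.0 \
  --criterion cross_entropy \
  --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \
  --lr-scheduler polynomial_decay --lr 3e-04 --warmup-updates 10000 --total-num-update 500000 \
  --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
  --max-tokens 2048 --update-freq 8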