Pretraining model examples

UER-py allows users to combine different modules (e.g. embedding, encoder, and target modules) to construct pre-training models. Here are some examples of combining modules to implement frequently-used pre-training models. In most cases, the configuration file specifies the pre-training modules. In this section, we explicitly specify the modules on the command line; modules specified on the command line overwrite those specified in the configuration file.

RoBERTa

The example of pre-processing and pre-training for RoBERTa:

python3 preprocess.py --corpus_path corpora/book_review.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --dynamic_masking \
                      --data_processor mlm

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/bert/base_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 --learning_rate 1e-4 \
                    --data_processor mlm \
                    --embedding word pos seg --encoder transformer --mask fully_visible --target mlm --tie_weights

RoBERTa uses dynamic masking and the MLM target, and allows a sample to contain content from multiple documents. We don't recommend using --full_sentences when documents are short (e.g. reviews). Notice that RoBERTa removes the NSP target. The corpus for RoBERTa stores one document per line, which is different from the corpus used by BERT (a toy illustration of the two formats is given at the end of this subsection). In addition, the --data_processor specified in the pre-training stage should be the same as in the pre-processing stage. RoBERTa can load BERT models for incremental pre-training (and vice versa). The example of doing incremental pre-training upon an existing BERT model:

python3 preprocess.py --corpus_path corpora/book_review.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --dynamic_masking \
                      --data_processor mlm

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --pretrained_model_path models/google_zh_model.bin \
                    --config_path models/bert/base_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 --learning_rate 2e-5 \
                    --data_processor mlm \
                    --embedding word pos seg --encoder transformer --mask fully_visible --target mlm --tie_weights
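As noted above, the RoBERTa-style corpus stores one document per line, while the BERT-style corpus (e.g. book_review_bert.txt) stores one sentence per line and separates documents with an empty line. A toy Python sketch of the two formats (file names and review contents below are made up for illustration):

# RoBERTa/GPT-style corpus: one document per line.
with open("roberta_style_corpus.txt", "w", encoding="utf-8") as f:
    f.write("这本书很好,情节很精彩。\n")     # document 1
    f.write("另一本书,内容比较一般。\n")     # document 2

# BERT-style corpus: one sentence per line, documents separated by an empty line.
with open("bert_style_corpus.txt", "w", encoding="utf-8") as f:
    f.write("这本书很好。\n情节很精彩。\n")   # document 1, two sentences
    f.write("\n")                             # empty line between documents
    f.write("另一本书。\n内容比较一般。\n")   # document 2, two sentences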

ALBERT

The example of pre-processing and pre-training for ALBERT:

python3 preprocess.py --corpus_path corpora/book_review_bert.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --dynamic_masking \
                      --data_processor albert

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/albert/base_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 --learning_rate 5e-5 \
                    --data_processor albert \
                    --embedding word pos seg --factorized_embedding_parameterization \
                    --encoder transformer --mask fully_visible --parameter_sharing \
                    --target mlm sp

The corpus format of ALBERT is identical to that of BERT.
--data_processor albert denotes using the ALBERT dataset format.
--target mlm sp denotes using the ALBERT target, which consists of mlm and sp (sentence order prediction) targets.
--factorized_embedding_parameterization denotes using factorized embedding parameterization to untie the embedding size from the hidden layer size.
--parameter_sharing denotes sharing all parameters (including feed-forward and attention parameters) across layers.
We provide 4 configuration files for the ALBERT model in the models/albert folder: base_config.json, large_config.json, xlarge_config.json, and xxlarge_config.json.
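The two ALBERT-specific options can be illustrated with a minimal PyTorch sketch (toy sizes and layer types are assumptions; this shows the idea only and is not UER-py's implementation):

import torch
import torch.nn as nn

vocab_size, emb_size, hidden_size, layers_num = 21128, 128, 768, 12

# Factorized embedding parameterization: a V x E embedding followed by an
# E x H projection, instead of a single V x H embedding matrix.
word_embedding = nn.Embedding(vocab_size, emb_size)
embedding_projection = nn.Linear(emb_size, hidden_size)

# Parameter sharing: the same Transformer layer is reused at every depth.
shared_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=12, batch_first=True)

x = embedding_projection(word_embedding(torch.randint(0, vocab_size, (2, 16))))
for _ in range(layers_num):
    x = shared_layer(x)
print(x.shape)  # torch.Size([2, 16, 768])
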
The example of doing incremental pre-training upon Google's ALBERT pre-trained models of different sizes:

python3 preprocess.py --corpus_path corpora/book_review_bert.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --dynamic_masking \
                      --data_processor albert

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --pretrained_model_path models/google_zh_albert_base_model.bin \
                    --config_path models/albert/base_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 --learning_rate 2e-5 \
                    --data_processor albert \
                    --embedding word pos seg --factorized_embedding_parameterization \
                    --encoder transformer --mask fully_visible --parameter_sharing \
                    --target mlm sp

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --pretrained_model_path models/google_zh_albert_xlarge_model.bin \
                    --config_path models/albert/xlarge_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 --learning_rate 2e-5 \
                    --data_processor albert \
                    --embedding word pos seg --factorized_embedding_parameterization \
                    --encoder transformer --mask fully_visible --parameter_sharing \
                    --target mlm sp

SpanBERT

SpanBERT introduces span masking and the span boundary objective. We only consider span masking here. The NSP target is removed by SpanBERT. The example of pre-processing and pre-training for SpanBERT (static masking):

python3 preprocess.py --corpus_path corpora/book_review.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 \
                      --dup_factor 20 --span_masking --span_geo_prob 0.3 --span_max_length 5 \
                      --data_processor mlm

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/bert/base_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7  --learning_rate 1e-4 \
                    --total_steps 10000 --save_checkpoint_steps 5000 \
                    --data_processor mlm \
                    --embedding word pos seg --encoder transformer --mask fully_visible --target mlm --tie_weights

--dup_factor specifies the number of times to duplicate the input data (with different masks). The default value is 5. The example of pre-processing and pre-training for SpanBERT (dynamic masking):

python3 preprocess.py --corpus_path corpora/book_review.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 \
                      --dynamic_masking --data_processor mlm

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/bert/base_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7  --learning_rate 1e-4 \
                    --span_masking --span_geo_prob 0.3 --span_max_length 5 \
                    --total_steps 10000 --save_checkpoint_steps 5000 \
                    --data_processor mlm \
                    --embedding word pos seg --encoder transformer --mask fully_visible --target mlm --tie_weights
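A minimal sketch of how --span_geo_prob and --span_max_length interact, assuming the clipped geometric distribution described in the SpanBERT paper (an illustration of the idea only, which may differ from UER-py's exact implementation):

import numpy as np

def sample_span_length(geo_prob=0.3, max_length=5):
    # Span lengths follow a geometric distribution with p = --span_geo_prob,
    # clipped at --span_max_length.
    return min(np.random.geometric(geo_prob), max_length)

print([sample_span_length() for _ in range(10)])  # e.g. [2, 1, 5, 1, 3, ...]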

BERT-WWM

BERT-WWM introduces whole word masking. The MLM target is used here. The example of pre-processing and pre-training for BERT-WWM (static masking):

python3 preprocess.py --corpus_path corpora/book_review.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 \
                      --dup_factor 20 --whole_word_masking \
                      --data_processor mlm

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/bert/base_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7  --learning_rate 1e-4 \
                    --total_steps 10000 --save_checkpoint_steps 5000 \
                    --data_processor mlm \
                    --embedding word pos seg --encoder transformer --mask fully_visible --target mlm --tie_weights

--whole_word_masking denotes that whole word masking is used.
The example of pre-processing and pre-training for BERT-WWM (dynamic masking):

python3 preprocess.py --corpus_path corpora/book_review.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 \
                      --dynamic_masking --data_processor mlm

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/bert/base_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7  --learning_rate 1e-4 \
                    --whole_word_masking \
                    --total_steps 10000 --save_checkpoint_steps 5000 \
                    --data_processor mlm \
                    --embedding word pos seg --encoder transformer --mask fully_visible --target mlm --tie_weights

BERT-WWM implemented in UER-py is only applicable to Chinese. jieba is used as the word segmentation tool (see uer/utils/data.py):

import jieba
wordlist = jieba.cut(sentence)

One can change the code in uer/utils/data.py to replace jieba with another word segmentation tool.
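A toy sketch of the whole word masking idea, assuming character-level tokenization for Chinese and jieba word boundaries (the helper function and masking probability below are hypothetical, not UER-py's code):

import random
import jieba

def whole_word_mask(sentence, mask_prob=0.15, mask_token="[MASK]"):
    tokens = []
    for word in jieba.cut(sentence):
        chars = list(word)                             # Chinese BERT tokenizes to characters
        if random.random() < mask_prob:
            tokens.extend([mask_token] * len(chars))   # mask every character of the word together
        else:
            tokens.extend(chars)
    return tokens

print(whole_word_mask("这本书的情节很精彩"))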

GPT

The example of pre-processing and pre-training for GPT:

python3 preprocess.py --corpus_path corpora/book_review.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --data_processor lm

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/gpt2/config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 --learning_rate 1e-4 \
                    --data_processor lm \
                    --embedding word pos --encoder transformer --mask causal --target lm --tie_weights

The corpus format of GPT is identical to that of RoBERTa. We can pre-train GPT through --embedding word pos --encoder transformer --mask causal --target lm --tie_weights . In addition, we should modify models/gpt2/config.json , removing the "remove_embedding_layernorm": true and "layernorm_positioning": "pre" options.

GPT-2

The example of pre-processing and pre-training for GPT-2:

python3 preprocess.py --corpus_path corpora/book_review.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --data_processor lm

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/gpt2/config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 --learning_rate 1e-4 \
                    --data_processor lm \
                    --embedding word pos --remove_embedding_layernorm \
                    --encoder transformer --mask causal --layernorm_positioning pre \
                    --target lm --tie_weights

The corpus format of GPT-2 is identical to that of GPT and RoBERTa. Notice that the encoder of GPT-2 is different from the encoder of GPT: layer normalization is moved to the input of each sub-block (--layernorm_positioning pre) and an additional layer normalization is added after the final block. The layer normalization after the embedding layer should also be removed (--remove_embedding_layernorm).
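A minimal PyTorch sketch of the pre-layernorm block described above (toy sizes; an assumption-level illustration rather than UER-py's implementation, and the causal attention mask is omitted for brevity):

import torch
import torch.nn as nn

class PreLNBlock(nn.Module):
    # Layer normalization is applied to the *input* of each sub-block
    # (--layernorm_positioning pre), instead of after the residual connection.
    def __init__(self, hidden_size=768, heads_num=12):
        super().__init__()
        self.ln_1 = nn.LayerNorm(hidden_size)
        self.attn = nn.MultiheadAttention(hidden_size, heads_num, batch_first=True)
        self.ln_2 = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(nn.Linear(hidden_size, 4 * hidden_size),
                                 nn.GELU(),
                                 nn.Linear(4 * hidden_size, hidden_size))

    def forward(self, x):
        h = self.ln_1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        return x + self.mlp(self.ln_2(x))

print(PreLNBlock()(torch.randn(2, 16, 768)).shape)  # torch.Size([2, 16, 768])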

ELMo

The example of pre-processing and pre-training for ELMo:

python3 preprocess.py --corpus_path corpora/book_review.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --data_processor bilm

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt  \
                    --config_path models/rnn/bilstm_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 --learning_rate 5e-4 \
                    --data_processor bilm \
                    --embedding word --remove_embedding_layernorm --encoder bilstm --target bilm

The corpus format of ELMo is identical to that of GPT-2. We can pre-train ELMo through --embedding word --remove_embedding_layernorm --encoder bilstm and --target bilm . --embedding word denotes using traditional word embedding; LSTM does not require position embedding. In addition, we specify --remove_embedding_layernorm so that the layer normalization after the word embedding is removed.
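A minimal sketch of the bidirectional language model objective behind --target bilm (toy shapes and the use of two separate LSTMs are assumptions, not UER-py's implementation):

import torch
import torch.nn as nn

vocab_size, emb_size, hidden_size = 21128, 512, 512
emb = nn.Embedding(vocab_size, emb_size)
fwd_lstm = nn.LSTM(emb_size, hidden_size, batch_first=True)   # left-to-right direction
bwd_lstm = nn.LSTM(emb_size, hidden_size, batch_first=True)   # right-to-left direction
output_layer = nn.Linear(hidden_size, vocab_size)

x = torch.randint(0, vocab_size, (2, 16))                     # toy batch of token ids
h_fwd, _ = fwd_lstm(emb(x))                                   # trained to predict the next token
h_bwd, _ = bwd_lstm(emb(torch.flip(x, dims=[1])))             # trained to predict the previous token
next_token_logits = output_layer(h_fwd)
prev_token_logits = output_layer(h_bwd)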

T5

T5 proposes to use a seq2seq model to unify NLU and NLG tasks. Based on extensive experiments, T5 recommends using an encoder-decoder architecture and a BERT-style objective function (the model predicts the masked words). The example of using T5 for pre-training:

python3 preprocess.py --corpus_path corpora/book_review.txt --vocab_path models/google_zh_with_sentinel_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --seq_length 128 \
                      --dynamic_masking --data_processor t5

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_with_sentinel_vocab.txt \
                    --config_path models/t5/small_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
                    --learning_rate 1e-3 --batch_size 64 \
                    --span_masking --span_geo_prob 0.3 --span_max_length 5 \
                    --data_processor t5 \
                    --embedding word --relative_position_embedding --remove_embedding_layernorm --tgt_embedding word --share_embedding \
                    --encoder transformer --mask fully_visible --decoder transformer \
                    --layernorm_positioning pre --layernorm t5 --remove_attention_scale --remove_transformer_bias \
                    --target lm --tie_weights

The corpus format of T5 is identical to that of GPT-2. --relative_position_embedding denotes using relative position embedding. --remove_embedding_layernorm and --layernorm_positioning pre denote that pre-layernorm is used (the same as GPT-2). Since T5 uses an encoder-decoder architecture, we have to specify both --encoder and --decoder.
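The objective behind --data_processor t5 is the span-corruption (BERT-style) objective from the T5 paper: masked spans in the source are replaced by sentinel tokens and the decoder reconstructs them. A toy illustration (the sentinel token names and example sentence are assumptions; UER-py's t5 data processor may differ in details):

# Original tokens:
tokens = ["我", "喜", "欢", "这", "本", "书"]

# Suppose the spans "喜 欢" and "书" are selected for masking. Each masked span is
# replaced by a sentinel token in the source, and the target enumerates the spans.
source = ["我", "<extra_id_0>", "这", "本", "<extra_id_1>"]
target = ["<extra_id_0>", "喜", "欢", "<extra_id_1>", "书", "<extra_id_2>"]

print(" ".join(source))
print(" ".join(target))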

T5-v1_1

T5-v1_1 includes several improvements compared to the original T5 model. The example of using T5-v1_1 for pre-training:

python3 preprocess.py --corpus_path corpora/book_review.txt --vocab_path models/google_zh_with_sentinel_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --seq_length 128 \
                      --dynamic_masking --data_processor t5

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_with_sentinel_vocab.txt \
                    --config_path models/t5-v1_1/small_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
                    --learning_rate 1e-3 --batch_size 64 \
                    --span_masking --span_geo_prob 0.3 --span_max_length 5 \
                    --data_processor t5 \
                    --embedding word --relative_position_embedding --remove_embedding_layernorm --tgt_embedding word --share_embedding \
                    --encoder transformer --mask fully_visible --decoder transformer \
                    --layernorm_positioning pre --layernorm t5 --feed_forward gated --remove_attention_scale --remove_transformer_bias \
                    --target lm

The corpus format of T5-v1_1 is identical to that of T5. --feed_forward denotes the type of feed-forward layer. --tie_weights is removed, so there is no parameter sharing between the embedding layer and the layer before the softmax. T5-v1_1 and T5 have different configuration files.
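A minimal PyTorch sketch of a gated feed-forward layer in the spirit of --feed_forward gated (GEGLU-style, as used in T5-v1_1; toy sizes, not UER-py's implementation):

import torch
import torch.nn as nn

class GatedFeedForward(nn.Module):
    def __init__(self, hidden_size=512, ff_size=1024):
        super().__init__()
        self.wi_0 = nn.Linear(hidden_size, ff_size, bias=False)  # gate branch
        self.wi_1 = nn.Linear(hidden_size, ff_size, bias=False)  # value branch
        self.wo = nn.Linear(ff_size, hidden_size, bias=False)
        self.act = nn.GELU()

    def forward(self, x):
        # Element-wise product of the activated gate and the value branch.
        return self.wo(self.act(self.wi_0(x)) * self.wi_1(x))

print(GatedFeedForward()(torch.randn(2, 16, 512)).shape)  # torch.Size([2, 16, 512])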

PEGASUS

PEGASUS proposes the GSG (gap sentence generation) pre-training target, which predicts the sentences extracted from the document and is beneficial to the text summarization task. The example of using PEGASUS for pre-training:

python3  preprocess.py --corpus_path corpora/CLUECorpusSmall_bert_sampled.txt --vocab_path models/google_zh_vocab.txt \
                       --dataset_path dataset.pt --processes_num 8 --seq_length 512 --tgt_seq_length 256 \
                       --dup_factor 1 --data_processor gsg --sentence_selection_strategy lead

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/pegasus/base_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
                    --learning_rate 1e-4 --batch_size 8 \
                    --data_processor gsg \
                    --embedding word sinusoidalpos --tgt_embedding word sinusoidalpos --remove_embedding_layernorm --share_embedding \
                    --encoder transformer --mask fully_visible --layernorm_positioning pre --decoder transformer \
                    --target lm --has_lmtarget_bias --tie_weights

The corpus format of PEGASUS is identical to that of BERT. In the pre-processing stage, --sentence_selection_strategy denotes the strategy for sentence selection in PEGASUS. When random sentence selection is used (--sentence_selection_strategy random), one can use --dup_factor to specify the number of times to duplicate the input data (with different masks on sentences). When --sentence_selection_strategy lead is specified, --dup_factor should be set to 1.
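A toy sketch of the lead sentence selection strategy for GSG (the helper function and mask token below are hypothetical; UER-py's gsg data processor may differ in details):

def gsg_lead(document_sentences, gap_sentences_num=1, mask_token="[MASK]"):
    # The leading sentences become the target sequence to generate, and each of
    # them is replaced by a mask token in the source document.
    target = document_sentences[:gap_sentences_num]
    source = [mask_token] * gap_sentences_num + document_sentences[gap_sentences_num:]
    return source, target

print(gsg_lead(["第一句。", "第二句。", "第三句。"]))
# (['[MASK]', '第二句。', '第三句。'], ['第一句。'])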

BART

BART proposes to use a seq2seq model to reconstruct the corrupted document: the encoder receives the corrupted document and the decoder reconstructs it. BART explores different corruption strategies and recommends using the combination of sentence permutation and text infilling (using a single MASK token to mask consecutive tokens). The example of using BART for pre-training:

python3 preprocess.py --corpus_path corpora/CLUECorpusSmall_bert_sampled.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --seq_length 512 --data_processor bart

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/bart/base_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
                    --learning_rate 1e-4 --batch_size 8 \
                    --span_masking --span_geo_prob 0.3 --span_max_length 5 \
                    --data_processor bart \
                    --embedding word pos --tgt_embedding word pos --share_embedding \
                    --encoder transformer --mask fully_visible --decoder transformer \
                    --target lm --tie_weights --has_lmtarget_bias
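A toy sketch of the recommended corruption (sentence permutation plus text infilling with a single mask token per span); the helper below is a simplified illustration, not UER-py's implementation:

import random

def bart_corrupt(sentences, mask_token="[MASK]"):
    random.shuffle(sentences)                         # sentence permutation
    corrupted = []
    for sentence in sentences:
        tokens = list(sentence)
        start = random.randrange(len(tokens))
        length = random.randint(1, 3)                 # toy span length
        tokens[start:start + length] = [mask_token]   # text infilling: one mask per span
        corrupted.append("".join(tokens))
    return corrupted                                  # the decoder reconstructs the original document

print(bart_corrupt(["这本书很好。", "情节很精彩。"]))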

XLM-RoBERTa

We download the multi-lingual pre-trained models XLM-RoBERTa-base and XLM-RoBERTa-large and do further pre-training upon them. Taking XLM-RoBERTa-base as an example, we first convert the pre-trained model into the UER format:

python3 scripts/convert_xlmroberta_from_huggingface_to_uer.py --input_model_path models/xlmroberta_base_model_huggingface.bin \
                                                                          --output_model_path models/xlmroberta_base_model_uer.bin \
                                                                          --layers_num 12

Since the special tokens used in the original pre-trained XLM-RoBERTa model are different from the ones used in BERT, we need to change the path of the special tokens mapping file in uer/utils/constants.py from models/special_tokens_map.json to models/xlmroberta_special_tokens_map.json. Then we do further pre-training upon the XLM-RoBERTa-base model:

python3 preprocess.py --corpus_path corpora/book_review.txt \
                      --spm_model_path models/xlmroberta_spm.model --tokenizer xlmroberta \
                      --dataset_path xlmroberta_zh_dataset.pt --processes_num 8 --seq_length 128 --dynamic_masking \
                      --data_processor mlm

python3 pretrain.py --dataset_path xlmroberta_zh_dataset.pt --spm_model_path models/xlmroberta_spm.model --tokenizer xlmroberta \
                    --pretrained_model_path models/xlmroberta_base_model_uer.bin \
                    --config_path models/xlm-roberta/base_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 --batch_size 8 \
                    --total_steps 100000 --save_checkpoint_steps 10000 --report_steps 100 \
                    --data_processor mlm \
                    --embedding word pos seg --encoder transformer --mask fully_visible --target mlm

Compared with the commonly used BERT and RoBERTa models, the original XLM-RoBERTa uses a different tokenization strategy (--tokenizer xlmroberta --spm_model_path models/xlmroberta_spm.model) and a different special tokens mapping file.
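A minimal sketch of the sentencepiece-based tokenization behind --tokenizer xlmroberta (this assumes models/xlmroberta_spm.model has been downloaded as above; the output pieces are only illustrative):

import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="models/xlmroberta_spm.model")
print(sp.encode("这本书很好", out_type=str))   # list of sentencepiece subword pieces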

Prefix LM

The example of using prefix LM for pre-training (which is used in UniLM):

python3 preprocess.py --corpus_path corpora/csl_abstract_title.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --seq_length 256 --data_processor prefixlm

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/bert/base_config.json --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 --learning_rate 1e-4 \
                    --total_steps 5000 --save_checkpoint_steps 1000 \
                    --data_processor prefixlm \
                    --embedding word pos seg --encoder transformer --mask causal_with_prefix --target lm --tie_weights

csl_abstract_title.txt is a Chinese scientific literature corpus. The original data can be found here. The abstract and title sequences are separated by \t , which is the corpus format of --data_processor prefixlm . We can pre-train the prefix LM model through --mask causal_with_prefix and --target lm. Notice that the model uses the segment information to determine which part is the prefix. Therefore we have to use --embedding word pos seg.
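A minimal sketch of the causal-with-prefix attention mask (1 means the row token may attend to the column token; toy sizes, not UER-py's implementation):

import torch

def causal_with_prefix_mask(seq_length, prefix_length):
    mask = torch.tril(torch.ones(seq_length, seq_length))   # causal part
    mask[:, :prefix_length] = 1.0                            # every token may attend to the prefix;
                                                             # prefix tokens attend to each other bidirectionally
    return mask

print(causal_with_prefix_mask(6, 3))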

RealFormer

RealFormer proposes to use residual attention to achieve better performance with less pre-training budget. The example of using RealFormer for pre-training:

python3 preprocess.py --corpus_path corpora/book_review.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --dynamic_masking \
                      --data_processor mlm

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/bert/base_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 --learning_rate 1e-4 \
                    --data_processor mlm \
                    --embedding word pos seg --encoder transformer --mask fully_visible --has_residual_attention --target mlm --tie_weights

--has_residual_attention denotes using residual attention in the Transformer encoder.
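A minimal sketch of the residual attention connection (toy shapes; an illustration of the idea, not UER-py's implementation):

import torch

batch_size, heads_num, seq_length = 2, 8, 16
prev_scores = torch.zeros(batch_size, heads_num, seq_length, seq_length)
for layer in range(12):
    scores = torch.randn(batch_size, heads_num, seq_length, seq_length)  # Q·K^T / sqrt(d) of this layer
    scores = scores + prev_scores          # add the previous layer's raw attention scores
    probs = scores.softmax(dim=-1)         # softmax over the combined scores
    prev_scores = scores                   # pass the pre-softmax scores to the next layer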

More combinations

The example of using LSTM encoder and LM target for pre-training:

python3 preprocess.py --corpus_path corpora/book_review.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --data_processor lm

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/rnn/lstm_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
                    --total_steps 20000 --save_checkpoint_steps 5000 --learning_rate 1e-3 \
                    --data_processor lm \
                    --embedding word --remove_embedding_layernorm --encoder lstm --target lm

We use models/rnn/lstm_config.json as the configuration file.

The example of using GRU encoder and LM target for pre-training:

python3 preprocess.py --corpus_path corpora/book_review.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --data_processor lm

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/rnn/gru_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 --learning_rate 1e-3 \
                    --total_steps 20000 --save_checkpoint_steps 5000 \
                    --data_processor lm \
                    --embedding word --remove_embedding_layernorm --encoder gru --target lm

The example of using GatedCNN encoder and LM target for pre-training:

python3 preprocess.py --corpus_path corpora/book_review.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --data_processor lm

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/cnn/gatedcnn_9_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 --learning_rate 1e-4 \
                    --total_steps 20000 --save_checkpoint_steps 5000 \
                    --data_processor lm \
                    --embedding word --remove_embedding_layernorm --encoder gatedcnn --target lm

The example of using machine translation for pre-training (the objective is the same as CoVe's, but the Transformer encoder and decoder are used):

python3 preprocess.py --corpus_path corpora/news-commentary-v13-zh-en.txt \
                      --vocab_path models/google_zh_vocab.txt --tgt_vocab_path models/google_uncased_en_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --seq_length 64 --tgt_seq_length 64 \
                      --data_processor mt

python3 pretrain.py --dataset_path dataset.pt \
                    --vocab_path models/google_zh_vocab.txt --tgt_vocab_path models/google_uncased_en_vocab.txt \
                    --config_path models/transformer/base_config.json --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 --learning_rate 1e-4 \
                    --total_steps 50000 --save_checkpoint_steps 10000 --report_steps 1000 \
                    --data_processor mt \
                    --embedding word sinusoidalpos --tgt_embedding word sinusoidalpos \
                    --encoder transformer --mask fully_visible --decoder transformer \
                    --target lm

news-commentary-v13-zh-en.txt is a Chinese-English parallel corpus (see the pretraining data section for more information). The source and target sequences are separated by \t , which is the corpus format of --data_processor mt . The pre-trained encoder can be used for downstream tasks.

The example of using Transformer encoder and classification (CLS) target for pre-training:

python3 preprocess.py --corpus_path corpora/book_review_cls.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --data_processor cls

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/bert/base_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
                    --total_steps 2000 --save_checkpoint_steps 1000 --learning_rate 2e-5 \
                    --data_processor cls \
                    --embedding word pos seg --encoder transformer --mask fully_visible \
                    --pooling first --target cls --labels_num 2

Notice that we need to explicitly specify the number of labels by --labels_num. The format of the corpus for the classification target is as follows (text classification and text pair classification):

1        instance1
0        instance2
1        instance3
1        instance1_text_a        instance1_text_b
0        instance2_text_a        instance2_text_b
1        instance3_text_a        instance3_text_b

\t is used to separate different columns (see book_review_cls.txt in the corpora folder).

The example of using LSTM encoder and classification (CLS) target for pre-training:

python3 preprocess.py --corpus_path corpora/book_review_cls.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --data_processor cls

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/rnn/lstm_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
                    --total_steps 2000 --save_checkpoint_steps 1000 --learning_rate 1e-3 \
                    --data_processor cls \
                    --embedding word --remove_embedding_layernorm --encoder lstm \
                    --pooling max --target cls --labels_num 2
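A minimal sketch of the two pooling strategies used with the classification target above (--pooling first and --pooling max); the shapes are toy values, not UER-py's implementation:

import torch

hidden = torch.randn(2, 16, 768)        # (batch_size, seq_length, hidden_size)
first_pooled = hidden[:, 0, :]          # --pooling first: hidden state of the first token
max_pooled, _ = hidden.max(dim=1)       # --pooling max: element-wise max over the sequence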

For the classification corpus, it is beneficial to use the classification and MLM targets jointly. The example of using the Transformer encoder and classification (CLS) + MLM targets for pre-training:

python3 preprocess.py --corpus_path corpora/book_review_cls.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --dynamic_masking \
                      --data_processor cls_mlm

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/bert/base_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
                    --total_steps 5000 --save_checkpoint_steps 1000 --learning_rate 2e-5 \
                    --data_processor cls_mlm \
                    --embedding word pos seg --encoder transformer --mask fully_visible \
                    --pooling first --target cls mlm --labels_num 2