
Quickstart

zhezhaoa edited this page Mar 14, 2024 · 44 revisions

Pre-training and text classification with BERT

This section uses several common examples to demonstrate how to use UER-py. More details are discussed in the Instructions section. We first use BERT (a text pre-training model) on the book review sentiment classification dataset: we pre-train the model on a book review corpus and then fine-tune it on the book review sentiment classification dataset. There are three input files: the book review corpus, the book review sentiment classification dataset, and the vocabulary. All files are encoded in UTF-8 and included in this project.

The format of the corpus for BERT is as follows (one sentence per line and documents are delimited by empty lines):

doc1-sent1
doc1-sent2
doc1-sent3

doc2-sent1

doc3-sent1
doc3-sent2

The book review corpus is obtained from the book review sentiment classification dataset. We remove the labels and split each review into two parts at its middle, so that each review becomes a document with two sentences (see book_review_bert.txt in the corpora folder).
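A minimal Python sketch of how such a corpus could be built from the labeled dataset. It assumes each review is split at its character midpoint; the helper names strip_labels and build_bert_corpus are hypothetical and not part of UER-py.

```python
def strip_labels(tsv_lines):
    """Drop the header row and the label column of a label<TAB>text_a file."""
    return [line.rstrip("\n").split("\t", 1)[1] for line in tsv_lines[1:] if line.strip()]


def build_bert_corpus(reviews):
    """Turn each review into a two-sentence document in BERT corpus format.

    Assumption: a review is split at its character midpoint.
    Documents are separated by an empty line.
    """
    lines = []
    for review in reviews:
        mid = len(review) // 2
        lines.append(review[:mid])  # first "sentence"
        lines.append(review[mid:])  # second "sentence"
        lines.append("")            # empty line delimits documents
    return "\n".join(lines)


corpus = build_bert_corpus(strip_labels(["label\ttext_a", "1\t这本书很好看", "0\t不推荐"]))
```

Here corpus holds two documents: 这本书 / 很好看 and 不 / 推荐, each followed by an empty line.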

The format of the classification dataset is as follows:

label    text_a
1        instance1
0        instance2
1        instance3

Columns are separated by a tab (\t). The first row lists the column names. For n-way classification, each label ID must be an integer between 0 and n-1 (inclusive).
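A hypothetical validator for this format, which can be run before fine-tuning. It checks the header, the column count of every row, and that each label lies in [0, n-1]. validate_classification_tsv is an illustrative name, not a UER-py function.

```python
def validate_classification_tsv(lines, labels_num):
    """Return the number of instances in a tab-separated classification file, or raise ValueError."""
    header = lines[0].rstrip("\n").split("\t")
    if "label" not in header or "text_a" not in header:
        raise ValueError("header must contain 'label' and 'text_a' columns")
    label_col = header.index("label")
    count = 0
    for row_num, line in enumerate(lines[1:], start=2):
        if not line.strip():
            continue  # tolerate blank lines, e.g. a trailing newline
        fields = line.rstrip("\n").split("\t")
        if len(fields) != len(header):
            raise ValueError(f"row {row_num}: expected {len(header)} columns, got {len(fields)}")
        label = int(fields[label_col])  # non-integer labels also raise ValueError
        if not 0 <= label < labels_num:
            raise ValueError(f"row {row_num}: label {label} is outside [0, {labels_num - 1}]")
        count += 1
    return count
```

For the three-row example above, validate_classification_tsv(lines, labels_num=2) returns 3.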

We use Google's Chinese vocabulary file models/google_zh_vocab.txt, which contains 21128 tokens (Chinese characters as well as special tokens, punctuation, and other symbols).

We first pre-process the book review corpus. In the pre-processing stage, the corpus is converted into the format required by the specified pre-training model (--data_processor):

python3 preprocess.py --corpus_path corpora/book_review_bert.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --data_processor bert

Notice that six>=1.12.0 is required.

Pre-processing is time-consuming; using multiple processes (--processes_num) can greatly speed it up. The BERT tokenizer is used by default (--tokenizer bert). After pre-processing, the raw text is converted into dataset.pt, which is the input of pretrain.py.

Then we download Google's pre-trained Chinese BERT model google_zh_model.bin (in UER format; the original model is from here) and put it in the models folder. We load the pre-trained Chinese BERT model and further pre-train it on the book review corpus. A pre-training model is usually composed of embedding, encoder, and target layers, and building one requires this information. The configuration file (--config_path) specifies the modules and hyper-parameters used by the pre-training model; see models/bert/base_config.json for details. Suppose we have a machine with 8 GPUs:

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --pretrained_model_path models/google_zh_model.bin \
                    --config_path models/bert/base_config.json \
                    --output_model_path models/book_review_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
                    --total_steps 5000 --save_checkpoint_steps 1000 --batch_size 32

mv models/book_review_model.bin-5000 models/book_review_model.bin

Note that pretrain.py attaches a suffix recording the training step (--total_steps) to the output model name. We can remove the suffix for ease of use.
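If the rename step is scripted, a minimal Python equivalent of the mv command could look like the sketch below. The function names are hypothetical; it only assumes that the final checkpoint is named <output_model_path>-<total_steps>, as in the command above.

```python
import os
import re


def strip_step_suffix(path):
    """Drop a trailing '-<digits>' step suffix, e.g. 'model.bin-5000' -> 'model.bin'."""
    return re.sub(r"-\d+$", "", path)


def rename_final_checkpoint(output_model_path, total_steps):
    """Python equivalent of: mv <output_model_path>-<total_steps> <output_model_path>"""
    os.replace(f"{output_model_path}-{total_steps}", output_model_path)
    return output_model_path
```

For example, rename_final_checkpoint("models/book_review_model.bin", 5000) mirrors the mv command shown above.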

Then we fine-tune the pre-trained model on downstream classification dataset. We use embedding and encoder layers of book_review_model.bin, which is the output of pretrain.py:

python3 finetune/run_classifier.py --pretrained_model_path models/book_review_model.bin \
                                   --vocab_path models/google_zh_vocab.txt \
                                   --config_path models/bert/base_config.json \
                                   --train_path datasets/book_review/train.tsv \
                                   --dev_path datasets/book_review/dev.tsv \
                                   --test_path datasets/book_review/test.tsv \
                                   --epochs_num 3 --batch_size 32

The default path of the fine-tuned classifier model is models/finetuned_model.bin. Note that the actual batch size of pre-training is --batch_size times --world_size, while the actual batch size of a downstream task (e.g. classification) is --batch_size. Then we do inference with the fine-tuned model:
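The batch size arithmetic for the commands above, as a small Python sketch. It assumes no gradient accumulation; effective_batch_size is an illustrative helper, not a UER-py function.

```python
def effective_batch_size(batch_size, world_size=1):
    """Instances per optimizer step when each of world_size processes loads batch_size."""
    return batch_size * world_size


pretrain_batch = effective_batch_size(32, world_size=8)  # 8 GPUs x 32 = 256
finetune_batch = effective_batch_size(32)                # fine-tuning uses --batch_size as-is: 32
instances_seen = 5000 * pretrain_batch                   # --total_steps 5000 -> 1,280,000 instances
```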

python3 inference/run_classifier_infer.py --load_model_path models/finetuned_model.bin \
                                          --vocab_path models/google_zh_vocab.txt \
                                          --config_path models/bert/base_config.json \
                                          --test_path datasets/book_review/test_nolabel.tsv \
                                          --prediction_path datasets/book_review/prediction.tsv \
                                          --labels_num 2

--test_path specifies the path of the file to be predicted; the file should contain a text_a column. --prediction_path specifies the path of the output file with prediction results. The number of labels must be specified explicitly with --labels_num; the book review dataset is a two-way classification dataset, so --labels_num 2 is used.
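If only a labeled test file is at hand, an unlabeled copy can be produced by dropping the label column. This sketch assumes test_nolabel.tsv is simply test.tsv without its label column; drop_label_column is a hypothetical helper.

```python
def drop_label_column(tsv_lines):
    """Remove the 'label' column, keeping text_a (and text_b, if present)."""
    header = tsv_lines[0].rstrip("\n").split("\t")
    keep = [i for i, name in enumerate(header) if name != "label"]
    out = []
    for line in tsv_lines:
        if not line.strip():
            continue  # skip blank lines
        fields = line.rstrip("\n").split("\t")
        out.append("\t".join(fields[i] for i in keep))
    return out
```

For example, ["label\ttext_a", "1\tinstance1"] becomes ["text_a", "instance1"].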

We can also use google_zh_model.bin and fine-tune it on downstream classification dataset:

python3 finetune/run_classifier.py --pretrained_model_path models/google_zh_model.bin \
                                   --vocab_path models/google_zh_vocab.txt \
                                   --config_path models/bert/base_config.json \
                                   --train_path datasets/book_review/train.tsv \
                                   --dev_path datasets/book_review/dev.tsv \
                                   --test_path datasets/book_review/test.tsv \
                                   --epochs_num 3 --batch_size 32

As mentioned above, the information related to the model is generally placed in the configuration file. Here we further introduce the BERT-base configuration file models/bert/base_config.json :

{
  "emb_size": 768,
  "feedforward_size": 3072,
  "hidden_size": 768,
  "hidden_act": "gelu",
  "heads_num": 12,
  "layers_num": 12,
  "max_seq_length": 512,
  "dropout": 0.1,
  "data_processor": "bert",
  "embedding": ["word", "pos", "seg"],
  "encoder": "transformer",
  "mask": "fully_visible",
  "target": ["mlm", "sp"],
  "tie_weights": true
}

The embedding layer of BERT is the sum of word (token), position, and segment embeddings, so "embedding": ["word", "pos", "seg"] is used. BERT uses the Transformer encoder ("encoder": "transformer"). Since each token can attend to all tokens in the sequence, we use the fully_visible mask type ("mask": "fully_visible"). In terms of targets, BERT uses Masked LM and Next Sentence Prediction ("target": ["mlm", "sp"]). "tie_weights": true specifies that parameters are shared between the embedding layer and the layer before the softmax. The format of dataset.pt is specified in the pre-training stage ("data_processor": "bert") and must match the format used in the pre-processing stage (--data_processor bert).
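To make the mask type concrete, here is a sketch of the attention pattern that "fully_visible" describes (1 = may attend), next to a left-to-right causal pattern for contrast. These helpers only illustrate the pattern; they are not UER-py's internal mask code.

```python
def fully_visible_mask(seq_len):
    """Every position may attend to every position (BERT, "mask": "fully_visible")."""
    return [[1] * seq_len for _ in range(seq_len)]


def causal_mask(seq_len):
    """For contrast: in a left-to-right LM, position i only attends to positions j <= i."""
    return [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]
```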


Specifying which GPUs are used

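We recommend using CUDA_VISIBLE_DEVICES to specify which GPUs are visible (all GPUs are used by default). Suppose GPU 0 and GPU 2 are available. Inside the process, the visible GPUs are renumbered from 0, so physical GPUs 0 and 2 are addressed with --world_size 2 --gpu_ranks 0 1: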

python3 preprocess.py --corpus_path corpora/book_review_bert.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --data_processor bert

CUDA_VISIBLE_DEVICES=0,2 python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                                             --pretrained_model_path models/google_zh_model.bin \
                                             --config_path models/bert/base_config.json \
                                             --output_model_path models/book_review_model.bin \
                                             --world_size 2 --gpu_ranks 0 1 \
                                             --total_steps 5000 --save_checkpoint_steps 1000 --batch_size 32

mv models/book_review_model.bin-5000 models/book_review_model.bin

CUDA_VISIBLE_DEVICES=0,2 python3 finetune/run_classifier.py --pretrained_model_path models/book_review_model.bin \
                                                            --vocab_path models/google_zh_vocab.txt \
                                                            --config_path models/bert/base_config.json \
                                                            --train_path datasets/book_review/train.tsv \
                                                            --dev_path datasets/book_review/dev.tsv \
                                                            --test_path datasets/book_review/test.tsv \
                                                            --output_model_path models/classifier_model.bin \
                                                            --epochs_num 3 --batch_size 32

CUDA_VISIBLE_DEVICES=0,2 python3 inference/run_classifier_infer.py --load_model_path models/classifier_model.bin \
                                                                   --vocab_path models/google_zh_vocab.txt \
                                                                   --config_path models/bert/base_config.json \
                                                                   --test_path datasets/book_review/test_nolabel.tsv \
                                                                   --prediction_path datasets/book_review/prediction.tsv \
                                                                   --labels_num 2

Note that in the fine-tuning stage we explicitly specify the path of the fine-tuned model with --output_model_path; the inference script then loads it with --load_model_path.
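The renumbering behind --gpu_ranks 0 1 can be sketched as a mapping from the physical IDs listed in CUDA_VISIBLE_DEVICES to the 0-based IDs the process sees. It assumes numeric GPU IDs (CUDA also accepts UUIDs, which this sketch does not handle); visible_to_logical is a hypothetical helper.

```python
def visible_to_logical(cuda_visible_devices):
    """Map physical GPU ids in CUDA_VISIBLE_DEVICES to the 0-based ids seen by the process."""
    physical = [int(x) for x in cuda_visible_devices.split(",") if x.strip()]
    return {gpu: logical for logical, gpu in enumerate(physical)}


mapping = visible_to_logical("0,2")   # {0: 0, 2: 1}
world_size = len(mapping)             # --world_size 2
gpu_ranks = sorted(mapping.values())  # --gpu_ranks 0 1
```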


Pre-training with MLM target

BERT includes the Next Sentence Prediction (NSP) target. However, NSP is not suitable for sentence-level reviews, since we have to split a review into multiple parts to construct a document. UER-py makes it easy to switch targets. Using Masked Language Modeling (MLM) as the only target (discarding NSP) can be a better choice for pre-training on reviews:

python3 preprocess.py --corpus_path corpora/book_review.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --dynamic_masking --data_processor mlm

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --pretrained_model_path models/google_zh_model.bin \
                    --config_path models/bert/base_config.json \
                    --output_model_path models/book_review_mlm_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
                    --total_steps 5000 --save_checkpoint_steps 1000 --batch_size 32 \
                    --data_processor mlm --target mlm

mv models/book_review_mlm_model.bin-5000 models/book_review_mlm_model.bin

CUDA_VISIBLE_DEVICES=0,1 python3 finetune/run_classifier.py --pretrained_model_path models/book_review_mlm_model.bin \
                                                            --vocab_path models/google_zh_vocab.txt \
                                                            --config_path models/bert/base_config.json \
                                                            --train_path datasets/book_review/train.tsv \
                                                            --dev_path datasets/book_review/dev.tsv \
                                                            --test_path datasets/book_review/test.tsv \
                                                            --epochs_num 3 --batch_size 32
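The --dynamic_masking flag passed to preprocess.py above means masks are drawn when an instance is used, not fixed once in dataset.pt. A minimal sketch of the standard BERT masking recipe (15% of positions selected; of those, 80% become [MASK], 10% a random token, 10% unchanged). UER-py's exact implementation may differ; mask_tokens is a hypothetical helper.

```python
import random

MASK = "[MASK]"


def mask_tokens(tokens, vocab, rng, mask_prob=0.15):
    """Return (masked tokens, {position: original token}) for an MLM loss."""
    masked = list(tokens)
    targets = {}
    for i, tok in enumerate(tokens):
        if rng.random() >= mask_prob:
            continue  # position not selected for prediction
        targets[i] = tok
        r = rng.random()
        if r < 0.8:
            masked[i] = MASK              # 80%: replace with [MASK]
        elif r < 0.9:
            masked[i] = rng.choice(vocab)  # 10%: replace with a random token
        # remaining 10%: keep the original token
    return masked, targets


# Dynamic masking: draw a fresh mask every time the instance is used.
rng = random.Random(0)
tokens = list("这本书很好看")
first_pass, _ = mask_tokens(tokens, tokens, rng)
second_pass, _ = mask_tokens(tokens, tokens, rng)
```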

In the pre-training stage, we still use the BERT-base configuration file models/bert/base_config.json. Since the target is changed, we specify --data_processor mlm --target mlm on the command line. Hyper-parameters can be specified in multiple ways, with the following priority order: command line, configuration file, and default setting. That is, hyper-parameters specified on the command line override those in the configuration file and the default settings.

Dynamic masking (--dynamic_masking, passed to preprocess.py) is used here: tokens are masked on the fly during pre-training instead of once at pre-processing time. Dynamic masking decreases the size of dataset.pt and usually improves performance, so it is recommended. Different targets require different corpus formats. The format of the corpus for the MLM target is as follows (one document per line):

doc1
doc2
doc3
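A corpus in the BERT format (one sentence per line, blank-line-delimited documents) can be turned into this one-document-per-line format by joining each document's sentences. This is a sketch only: we have not verified that corpora/book_review.txt was produced this way, and bert_corpus_to_mlm is a hypothetical helper. Use sep="" for Chinese and sep=" " for space-delimited text.

```python
def bert_corpus_to_mlm(text, sep=""):
    """Join the sentences of each blank-line-delimited document into one line."""
    docs, current = [], []
    for line in text.splitlines():
        if line.strip():
            current.append(line.strip())
        elif current:  # an empty line closes the current document
            docs.append(sep.join(current))
            current = []
    if current:  # last document may not be followed by an empty line
        docs.append(sep.join(current))
    return "\n".join(docs)
```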

Notice that corpora/book_review.txt (instead of corpora/book_review_bert.txt) is used when the target is switched to MLM.
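The hyper-parameter priority described above (command line over configuration file over default setting) can be sketched as a layered dictionary merge. The default values here are made up for illustration, and resolve_hyperparameters is not UER-py's actual argument-handling code.

```python
import json


def resolve_hyperparameters(defaults, config_json, cli_args):
    """Merge settings with priority: command line > configuration file > default."""
    merged = dict(defaults)
    merged.update(json.loads(config_json))
    # Only flags actually given on the command line (non-None) override earlier layers.
    merged.update({k: v for k, v in cli_args.items() if v is not None})
    return merged


config = '{"hidden_size": 768, "data_processor": "bert", "target": ["mlm", "sp"]}'
hypothetical_defaults = {"learning_rate": 2e-5, "hidden_size": 512}
cli = {"data_processor": "mlm", "target": ["mlm"], "learning_rate": None}
resolved = resolve_hyperparameters(hypothetical_defaults, config, cli)
```

Here data_processor and target come from the command line, hidden_size from the configuration file, and learning_rate from the defaults.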


Downstream task fine-tuning with BERT

Besides classification, UER-py also supports other downstream tasks. Various datasets can be downloaded from the Downstream datasets section. We download the LCQMC, MSRA-NER, and CMRC2018 datasets and put them in the datasets folder.

We could use run_classifier.py for text pair classification:

python3 finetune/run_classifier.py --pretrained_model_path models/google_zh_model.bin \
                                   --vocab_path models/google_zh_vocab.txt \
                                   --config_path models/bert/base_config.json \
                                   --train_path datasets/lcqmc/train.tsv \
                                   --dev_path datasets/lcqmc/dev.tsv \
                                   --test_path datasets/lcqmc/test.tsv \
                                   --output_model_path models/classifier_model.bin \
                                   --batch_size 32 --epochs_num 3 --seq_length 128

For text pair classification, the dataset should contain text_a, text_b, and label columns. Then we do inference with the fine-tuned text pair classification model:

python3 inference/run_classifier_infer.py --load_model_path models/classifier_model.bin \
                                          --vocab_path models/google_zh_vocab.txt \
                                          --config_path models/bert/base_config.json \
                                          --test_path datasets/lcqmc/test.tsv \
                                          --prediction_path datasets/lcqmc/prediction.tsv \
                                          --seq_length 128 --labels_num 2

The file to be predicted (--test_path) should contain text_a and text_b columns.
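For text pairs, BERT-style models pack text_a and text_b into one sequence. A sketch of that packing with longest-first truncation to --seq_length, using segment ids 0/1 as in the original BERT. UER-py's internal segment numbering, padding, and truncation strategy may differ; pack_pair is a hypothetical helper.

```python
def pack_pair(tokens_a, tokens_b, seq_length):
    """Build [CLS] a [SEP] b [SEP] plus segment ids, truncating the longer side first."""
    if seq_length < 3:
        raise ValueError("seq_length must leave room for [CLS] and two [SEP]")
    tokens_a, tokens_b = list(tokens_a), list(tokens_b)
    budget = seq_length - 3  # room left after [CLS] and two [SEP]
    while len(tokens_a) + len(tokens_b) > budget:
        (tokens_a if len(tokens_a) >= len(tokens_b) else tokens_b).pop()
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    segments = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    return tokens, segments
```

Padding up to seq_length is omitted here.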

We could use run_ner.py for named entity recognition:

python3 finetune/run_ner.py --pretrained_model_path models/google_zh_model.bin \
                            --vocab_path models/google_zh_vocab.txt \
                            --config_path models/bert/base_config.json \
                            --train_path datasets/msra_ner/train.tsv \
                            --dev_path datasets/msra_ner/dev.tsv \
                            --test_path datasets/msra_ner/test.tsv \
                            --output_model_path models/ner_model.bin \
                            --label2id_path datasets/msra_ner/label2id.json \
                            --epochs_num 5 --batch_size 16

--label2id_path specifies the path of label2id file for named entity recognition. Then we do inference with the fine-tuned NER model:

python3 inference/run_ner_infer.py --load_model_path models/ner_model.bin \
                                   --vocab_path models/google_zh_vocab.txt \
                                   --config_path models/bert/base_config.json \
                                   --test_path datasets/msra_ner/test_nolabel.tsv \
                                   --prediction_path datasets/msra_ner/prediction.tsv \
                                   --label2id_path datasets/msra_ner/label2id.json
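Predicted tag sequences are usually turned back into entity spans. The sketch below assumes BIO-style tags (e.g. B-PER, I-PER, O) and a label2id.json that maps tag strings to integer ids; we have not checked the exact MSRA-NER label set, and both helpers are hypothetical.

```python
import json


def load_id2label(label2id_json):
    """Invert a label2id mapping so predicted ids can be turned back into tags."""
    return {idx: label for label, idx in json.loads(label2id_json).items()}


def bio_to_spans(tags):
    """Convert BIO tags into (entity_type, start, end) spans, end exclusive."""
    spans, start, etype = [], None, None
    for i, tag in enumerate(tags + ["O"]):  # sentinel "O" flushes the last entity
        starts_new = tag.startswith("B-") or (tag.startswith("I-") and tag[2:] != etype)
        if tag == "O" or starts_new:
            if etype is not None:
                spans.append((etype, start, i))
            start, etype = (i, tag[2:]) if tag != "O" else (None, None)
        # an I- tag continuing the current entity needs no action
    return spans
```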

We could use run_cmrc.py for machine reading comprehension:

python3 finetune/run_cmrc.py --pretrained_model_path models/google_zh_model.bin \
                             --vocab_path models/google_zh_vocab.txt \
                             --config_path models/bert/base_config.json \
                             --train_path datasets/cmrc2018/train.json --dev_path datasets/cmrc2018/dev.json \
                             --output_model_path models/cmrc_model.bin \
                             --epochs_num 2 --batch_size 8 --seq_length 512

We don't specify --test_path because the CMRC2018 dataset doesn't provide labels for its test set. Then we do inference with the fine-tuned MRC model:

python3 inference/run_cmrc_infer.py --load_model_path models/cmrc_model.bin \
                                    --vocab_path models/google_zh_vocab.txt \
                                    --config_path models/bert/base_config.json \
                                    --test_path datasets/cmrc2018/test.json \
                                    --prediction_path datasets/cmrc2018/prediction.json \
                                    --seq_length 512
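A sketch for inspecting the MRC files, assuming CMRC2018 uses the SQuAD-style JSON layout (data, then paragraphs, then qas with answers). iter_cmrc_examples is a hypothetical helper; unlabeled entries (as in a test set) yield None for the answer fields.

```python
import json


def iter_cmrc_examples(raw_json):
    """Yield (qid, question, context, answer_text, answer_start) from SQuAD-style JSON."""
    for article in json.loads(raw_json)["data"]:
        for para in article["paragraphs"]:
            context = para["context"]
            for qa in para["qas"]:
                answers = qa.get("answers") or [{}]  # unlabeled questions have no answers
                first = answers[0]
                yield qa["id"], qa["question"], context, first.get("text"), first.get("answer_start")
```

For labeled entries, context[answer_start:answer_start + len(answer_text)] should equal answer_text.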

Cross validation for classification

UER-py supports cross validation for classification. Below is an example of cross validation on SMP2020-EWECT, a competition dataset:

CUDA_VISIBLE_DEVICES=0 python3 finetune/run_classifier_cv.py --pretrained_model_path models/google_zh_model.bin \
                                                             --vocab_path models/google_zh_vocab.txt \
                                                             --config_path models/bert/base_config.json \
                                                             --train_path datasets/smp2020-ewect/virus/train.tsv \
                                                             --train_features_path datasets/smp2020-ewect/virus/train_features.npy \
                                                             --output_model_path models/classifier_model.bin \
                                                             --epochs_num 3 --batch_size 32 --folds_num 5

The results of google_zh_model.bin are 79.1/63.8 (Accuracy/Macro F1). --folds_num specifies the number of folds (and therefore training rounds) of cross validation. --output_model_path specifies the path of the fine-tuned model; --folds_num models are saved, with the fold ID appended as a suffix to each model name. --train_features_path specifies the path of the out-of-fold (OOF) predictions: for each fold, run_classifier_cv.py trains a model on the other folds and uses it to predict class probabilities for the held-out fold. train_features.npy can then be used as features for stacking. More details are introduced in the Competition solutions section.
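A minimal sketch of the out-of-fold idea: every instance is predicted by a model that never saw it during training. It uses contiguous folds for simplicity, while run_classifier_cv.py may shuffle or split differently; fit and predict are hypothetical stand-ins for fine-tuning and probability prediction. In the real script each OOF entry is presumably a vector of class probabilities.

```python
def kfold_indices(n, folds_num):
    """Split range(n) into folds_num contiguous folds; yield (train_idx, held_out_idx)."""
    fold_sizes = [n // folds_num + (1 if i < n % folds_num else 0) for i in range(folds_num)]
    start = 0
    for size in fold_sizes:
        held_out = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n))
        yield train, held_out
        start += size


def out_of_fold_predictions(n, folds_num, fit, predict):
    """Predict each instance with the model trained on the folds that exclude it."""
    oof = [None] * n
    for train_idx, held_out in kfold_indices(n, folds_num):
        model = fit(train_idx)
        for i in held_out:
            oof[i] = predict(model, i)
    return oof
```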

We can further try different pre-trained models. For example, we download RoBERTa-wwm-ext-large from HIT and convert it into UER format:

python3 scripts/convert_bert_from_huggingface_to_uer.py --input_model_path models/chinese_roberta_wwm_large_ext_pytorch/pytorch_model.bin \
                                                        --output_model_path models/chinese_roberta_wwm_large_ext_pytorch/pytorch_uer_model.bin \
                                                        --layers_num 24

CUDA_VISIBLE_DEVICES=0,1 python3 finetune/run_classifier_cv.py --pretrained_model_path models/chinese_roberta_wwm_large_ext_pytorch/pytorch_uer_model.bin \
                                                               --vocab_path models/google_zh_vocab.txt \
                                                               --config_path models/bert/large_config.json \
                                                               --train_path datasets/smp2020-ewect/virus/train.tsv \
                                                               --train_features_path datasets/smp2020-ewect/virus/train_features.npy \
                                                               --output_model_path models/classifier_model.bin \
                                                               --folds_num 5 --epochs_num 3 --batch_size 32

The results of RoBERTa-wwm-ext-large are 80.3/66.8 (Accuracy/Macro F1). Note that the models/bert/large_config.json configuration file is used. For ease of use, we have converted many pre-trained weights from other organizations into UER format; see Modelzoo for more information.

Below is an example of using our RoBERTa-large model pre-trained on a review corpus:

CUDA_VISIBLE_DEVICES=0,1 python3 finetune/run_classifier_cv.py --pretrained_model_path models/review_roberta_large_model.bin \
                                                               --vocab_path models/google_zh_vocab.txt \
                                                               --config_path models/bert/large_config.json \
                                                               --train_path datasets/smp2020-ewect/virus/train.tsv \
                                                               --train_features_path datasets/smp2020-ewect/virus/train_features.npy \
                                                               --output_model_path models/classifier_model.bin \
                                                               --folds_num 5 --learning_rate 1e-5 --epochs_num 3 --batch_size 32 --seed 11

The results are 81.3/68.4 (Accuracy/Macro F1), which are very competitive compared with other open-source pre-trained models. The corpus used by the above pre-trained model is highly similar to SMP2020-EWECT, a Weibo review dataset.
Large models sometimes fail to converge; in that case, try different random seeds with --seed.


Using more encoders besides Transformer

UER-py supports encoders besides Transformer. Here we replace the 12-layer Transformer encoder with a 2-layer LSTM encoder. We first download cluecorpussmall_lstm_lm_model.bin for the 2-layer LSTM encoder. The model is pre-trained on the CLUECorpusSmall corpus for 500,000 steps:

python3 preprocess.py --corpus_path corpora/cluecorpussmall.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --seq_length 256 --data_processor lm

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/rnn/lstm_config.json \
                    --output_model_path models/cluecorpussmall_lstm_lm_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
                    --total_steps 500000 --save_checkpoint_steps 100000 \
                    --learning_rate 1e-3 --batch_size 64

Then we remove the training-step suffix from the pre-trained model and fine-tune it on the downstream classification dataset:

python3 finetune/run_classifier.py --pretrained_model_path models/cluecorpussmall_lstm_lm_model.bin \
                                   --vocab_path models/google_zh_vocab.txt \
                                   --config_path models/rnn/lstm_config.json \
                                   --train_path datasets/book_review/train.tsv \
                                   --dev_path datasets/book_review/dev.tsv \
                                   --test_path datasets/book_review/test.tsv \
                                   --learning_rate 1e-3 --epochs_num 5 --batch_size 64 \
                                   --pooling mean

python3 inference/run_classifier_infer.py --load_model_path models/finetuned_model.bin \
                                          --vocab_path models/google_zh_vocab.txt \
                                          --config_path models/rnn/lstm_config.json \
                                          --test_path datasets/book_review/test_nolabel.tsv \
                                          --prediction_path datasets/book_review/prediction.tsv \
                                          --pooling mean --labels_num 2

We achieve around 85% accuracy on the test set, which is a competitive result. The same LSTM encoder without pre-training only achieves around 81% accuracy.
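The --pooling flag controls how per-token encoder outputs are collapsed into one vector for the classifier. A pure-Python sketch of mean and max pooling over a [seq_len][hidden] list; it ignores padding masks for simplicity, and pool is a hypothetical helper, not UER-py code.

```python
def pool(hidden_states, mode="mean"):
    """Collapse a [seq_len][hidden] list of vectors into a single [hidden] vector."""
    if mode == "mean":
        return [sum(col) / len(hidden_states) for col in zip(*hidden_states)]
    if mode == "max":
        return [max(col) for col in zip(*hidden_states)]
    raise ValueError(f"unsupported pooling mode in this sketch: {mode}")
```

For hidden states [[1.0, 4.0], [3.0, 2.0]], mean pooling gives [2.0, 3.0] and max pooling gives [3.0, 4.0].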

UER-py also includes many other pre-training models. We download cluecorpussmall_elmo_model.bin, a pre-trained ELMo model. The model is pre-trained on the CLUECorpusSmall corpus for 500,000 steps:

python3 preprocess.py --corpus_path corpora/cluecorpussmall.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --seq_length 256 --data_processor bilm

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/rnn/bilstm_config.json \
                    --output_model_path models/cluecorpussmall_elmo_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
                    --total_steps 500000 --save_checkpoint_steps 100000 \
                    --learning_rate 5e-4 --batch_size 64

We remove the training-step suffix from the pre-trained model. Then we further pre-train it on the book review corpus and fine-tune it on the book review sentiment classification dataset:

python3 preprocess.py --corpus_path corpora/book_review.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --seq_length 192 --data_processor bilm

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --pretrained_model_path models/cluecorpussmall_elmo_model.bin \
                    --config_path models/rnn/bilstm_config.json \
                    --output_model_path models/book_review_elmo_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
                    --total_steps 5000 --save_checkpoint_steps 2500 \
                    --learning_rate 5e-4 --batch_size 64

mv models/book_review_elmo_model.bin-5000 models/book_review_elmo_model.bin

python3 finetune/run_classifier.py --pretrained_model_path models/book_review_elmo_model.bin \
                                   --vocab_path models/google_zh_vocab.txt \
                                   --config_path models/rnn/bilstm_config.json \
                                   --train_path datasets/book_review/train.tsv \
                                   --dev_path datasets/book_review/dev.tsv \
                                   --test_path datasets/book_review/test.tsv \
                                   --learning_rate 5e-4 --epochs_num 5 --batch_size 64 --seq_length 192 \
                                   --pooling max

python3 inference/run_classifier_infer.py --load_model_path models/finetuned_model.bin \
                                          --vocab_path models/google_zh_vocab.txt \
                                          --config_path models/rnn/bilstm_config.json \
                                          --test_path datasets/book_review/test_nolabel.tsv \
                                          --prediction_path datasets/book_review/prediction.tsv \
                                          --seq_length 192 --pooling max --labels_num 2

corpora/book_review.txt is derived from the book review sentiment classification dataset with the labels removed.

An example of fine-tuning GatedCNN on the classification dataset:

python3 finetune/run_classifier.py --pretrained_model_path models/cluecorpussmall_gatedcnn_lm_model.bin \
                                   --vocab_path models/google_zh_vocab.txt \
                                   --config_path models/cnn/gatedcnn_9_config.json \
                                   --train_path datasets/book_review/train.tsv \
                                   --dev_path datasets/book_review/dev.tsv \
                                   --test_path datasets/book_review/test.tsv \
                                   --learning_rate 5e-5 --epochs_num 5 --batch_size 64 \
                                   --pooling mean

python3 inference/run_classifier_infer.py --load_model_path models/finetuned_model.bin \
                                          --vocab_path models/google_zh_vocab.txt \
                                          --config_path models/cnn/gatedcnn_9_config.json \
                                          --test_path datasets/book_review/test_nolabel.tsv \
                                          --prediction_path datasets/book_review/prediction.tsv \
                                          --pooling mean --labels_num 2

Users can download cluecorpussmall_gatedcnn_lm_model.bin from here. The model is pre-trained on CLUECorpusSmall corpus for 500,000 steps:

python3 preprocess.py --corpus_path corpora/cluecorpussmall.txt --vocab_path models/google_zh_vocab.txt \
                      --dataset_path dataset.pt --processes_num 8 --seq_length 256 --data_processor lm

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/cnn/gatedcnn_9_config.json \
                    --output_model_path models/cluecorpussmall_gatedcnn_lm_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
                    --total_steps 500000 --save_checkpoint_steps 100000 --report_steps 100 \
                    --learning_rate 1e-4 --batch_size 64

Downstream task fine-tuning and text generation with language model

An example of fine-tuning GPT-2 on the classification dataset:

python3 finetune/run_classifier.py --pretrained_model_path models/cluecorpussmall_gpt2_seq1024_model.bin \
                                   --vocab_path models/google_zh_vocab.txt \
                                   --config_path models/gpt2/config.json \
                                   --train_path datasets/book_review/train.tsv \
                                   --dev_path datasets/book_review/dev.tsv \
                                   --test_path datasets/book_review/test.tsv \
                                   --epochs_num 3 --batch_size 32 \
                                   --pooling mean

An example of using GPT-2 to generate text:

python3 scripts/generate_lm.py --load_model_path models/cluecorpussmall_gpt2_seq1024_model.bin \
                               --vocab_path models/google_zh_vocab.txt \
                               --config_path models/gpt2/config.json \
                               --test_path beginning.txt --prediction_path generated_text.txt \
                               --seq_length 128

Users can download cluecorpussmall_gpt2_seq1024_model.bin from here. beginning.txt contains the beginning of a document.
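Conceptually, a language model continues beginning.txt one token at a time. The sketch below shows greedy left-to-right decoding with a toy bigram table standing in for the model. It is chosen for determinism only: generate_lm.py may sample instead, and next_token_scores is a hypothetical stand-in for a real model.

```python
def generate_greedy(beginning, next_token_scores, max_length):
    """Repeatedly append the highest-scoring next token until max_length or no candidates."""
    tokens = list(beginning)
    while len(tokens) < max_length:
        scores = next_token_scores(tokens)  # dict {token: score}
        if not scores:
            break
        tokens.append(max(scores, key=scores.get))
    return "".join(tokens)


# Toy bigram "model": scores depend only on the last token.
bigram = {"今": {"天": 1.0}, "天": {"气": 0.9, "下": 0.1}, "气": {"好": 1.0}}
text = generate_greedy("今", lambda toks: bigram.get(toks[-1], {}), max_length=10)
```

Here text is "今天气好": generation stops once the last token has no continuation in the table.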

An example of using an LSTM language model to generate text:

python3 scripts/generate_lm.py --load_model_path models/cluecorpussmall_lstm_lm_model.bin \
                               --vocab_path models/google_zh_vocab.txt \
                               --config_path models/rnn/lstm_config.json \
                               --test_path beginning.txt --prediction_path generated_text.txt \
                               --seq_length 128

An example of using a GatedCNN language model to generate text:

python3 scripts/generate_lm.py --load_model_path models/cluecorpussmall_gatedcnn_lm_model.bin \
                               --vocab_path models/google_zh_vocab.txt \
                               --config_path models/cnn/gatedcnn_9_config.json \
                               --test_path beginning.txt --prediction_path generated_text.txt \
                               --seq_length 128

Using different tokenizers and vocabularies

In most cases, we use --vocab_path models/google_zh_vocab.txt and --tokenizer bert to tokenize text. Since most scripts in this project use --tokenizer bert by default, --tokenizer is usually not specified explicitly. Next we show more use cases of tokenizers and vocabularies.

--tokenizer bert is character-based when processing Chinese. To pre-train a word-based model and fine-tune it, we first do word segmentation on the corpora/book_review.txt corpus, with words separated by spaces (corpora/book_review_seg.txt). Then we build a vocabulary based on the segmented corpus:

python3 scripts/build_vocab.py --corpus_path corpora/book_review_seg.txt \
                               --output_path models/book_review_word_vocab.txt \
                               --delimiter space --workers_num 8 --min_count 5
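What --min_count 5 does can be sketched as counting space-delimited words and keeping the frequent ones. The special-token list and the frequency-then-alphabetical ordering are assumptions; scripts/build_vocab.py may use a different token set and order.

```python
from collections import Counter

SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]  # assumed set


def build_vocab(lines, min_count=5):
    """Count space-delimited words; keep those seen at least min_count times, most frequent first."""
    counts = Counter(word for line in lines for word in line.split())
    words = sorted((w for w, c in counts.items() if c >= min_count), key=lambda w: (-counts[w], w))
    return SPECIAL_TOKENS + [w for w in words if w not in SPECIAL_TOKENS]
```

Writing the result one token per line would give a file in the same shape as a vocabulary file.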

--tokenizer space is used in the pre-processing and pre-training stages since words are separated by spaces. Examples of pre-processing and pre-training a word-based model:

python3 preprocess.py --corpus_path corpora/book_review_seg.txt \
                      --vocab_path models/book_review_word_vocab.txt --tokenizer space \
                      --dataset_path book_review_word_dataset.pt \
                      --processes_num 8 --seq_length 128 --dynamic_masking --data_processor mlm

python3 pretrain.py --dataset_path book_review_word_dataset.pt \
                    --vocab_path models/book_review_word_vocab.txt --tokenizer space \
                    --config_path models/bert/base_config.json \
                    --output_model_path models/book_review_word_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
                    --total_steps 5000 --save_checkpoint_steps 2500 --report_steps 500 \
                    --learning_rate 1e-4 --batch_size 64 \
                    --data_processor mlm --target mlm
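With --tokenizer space, each space-separated word in the segmented text is one token, and a word missing from models/book_review_word_vocab.txt cannot receive its own ID. The sketch below shows that lookup. It assumes that out-of-vocabulary words map to a [UNK] entry and that the vocabulary file holds one token per line, with IDs equal to line indices. It is not UER-py's tokenizer class:

```python
def load_vocab(path):
    """Assign each token the 0-based index of its line in the vocabulary file."""
    with open(path, encoding="utf-8") as f:
        return {line.rstrip("\r\n"): i for i, line in enumerate(f)}


def space_tokenize(text, vocab, unk="[UNK]"):
    """Split pre-segmented text on whitespace; replace out-of-vocabulary words with unk."""
    return [tok if tok in vocab else unk for tok in text.split()]


def tokens_to_ids(tokens, vocab):
    """Look up the ID of each token (unk must itself be present in the vocabulary)."""
    return [vocab[tok] for tok in tokens]
```

This is also why the train/dev/test files should be segmented with the same tool as the corpus. In this sketch, a word that is split differently at fine-tuning time is likely to fall back to [UNK].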

In the fine-tuning and inference stages, we also need to explicitly specify --vocab_path models/book_review_word_vocab.txt and --tokenizer space. The text in the train/dev/test datasets (the text_a and text_b columns) should be segmented with the same word segmentation tool as the corpus. We perform word segmentation on the files in the datasets/book_review/ folder and put the results in the datasets/book_review_seg/ folder. Then we rename the final pre-training checkpoint, fine-tune the word-based model, and run inference:

mv models/book_review_word_model.bin-5000 models/book_review_word_model.bin

python3 finetune/run_classifier.py --pretrained_model_path models/book_review_word_model.bin \
                                   --vocab_path models/book_review_word_vocab.txt --tokenizer space \
                                   --config_path models/bert/base_config.json \
                                   --train_path datasets/book_review_seg/train.tsv \
                                   --dev_path datasets/book_review_seg/dev.tsv \
                                   --test_path datasets/book_review_seg/test.tsv \
                                   --epochs_num 3 --batch_size 32

python3 inference/run_classifier_infer.py --load_model_path models/finetuned_model.bin \
                                          --vocab_path models/book_review_word_vocab.txt --tokenizer space \
                                          --config_path models/bert/base_config.json \
                                          --test_path datasets/book_review_seg/test_nolabel.tsv \
                                          --prediction_path datasets/book_review_seg/prediction.tsv \
                                          --labels_num 2
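The commands above do not show how the files in datasets/book_review/ are segmented. In the sketch below, segment stands for whichever word segmentation tool was applied to the corpus. It is a hypothetical callable that takes a string and returns a list of words. The sketch only rewrites the text_a and text_b columns of each tab-separated file and leaves the header and the labels untouched:

```python
import os


def segment_tsv(src_path, dst_path, segment, columns=("text_a", "text_b")):
    """Copy a tab-separated dataset, replacing text columns with space-joined words."""
    with open(src_path, encoding="utf-8") as fin:
        rows = [line.rstrip("\r\n").split("\t") for line in fin if line.strip()]
    header, body = rows[0], rows[1:]
    # Only columns present in the header are touched (test_nolabel.tsv has no label column).
    text_cols = [header.index(c) for c in columns if c in header]
    for row in body:
        for i in text_cols:
            row[i] = " ".join(segment(row[i]))
    os.makedirs(os.path.dirname(dst_path) or ".", exist_ok=True)
    with open(dst_path, "w", encoding="utf-8") as fout:
        for row in [header] + body:
            fout.write("\t".join(row) + "\n")


def segment_dir(src_dir, dst_dir, segment,
                names=("train.tsv", "dev.tsv", "test.tsv", "test_nolabel.tsv")):
    """Segment every dataset file that exists in src_dir into dst_dir."""
    for name in names:
        src = os.path.join(src_dir, name)
        if os.path.exists(src):
            segment_tsv(src, os.path.join(dst_dir, name), segment)
```

For example, segment_dir("datasets/book_review", "datasets/book_review_seg", segment=my_segmenter) writes the segmented copies, where my_segmenter is a placeholder for your segmentation tool.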

An example of using SentencePiece:

python3 preprocess.py --corpus_path corpora/book_review.txt \
                      --spm_model_path models/cluecorpussmall_spm.model \
                      --dataset_path book_review_word_sentencepiece_dataset.pt \
                      --processes_num 8 --seq_length 128 --dynamic_masking \
                      --data_processor mlm

python3 pretrain.py --dataset_path book_review_word_sentencepiece_dataset.pt \
                    --spm_model_path models/cluecorpussmall_spm.model \
                    --output_model_path models/book_review_word_sentencepiece_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
                    --total_steps 5000 --save_checkpoint_steps 2500 --report_steps 500 \
                    --learning_rate 1e-4 --batch_size 64 \
                    --data_processor mlm --target mlm
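Judging by the mv commands in this section, pretrain.py names each checkpoint after --output_model_path with the step appended. Here that is models/book_review_word_sentencepiece_model.bin-5000, and with --save_checkpoint_steps 2500 there is presumably also a -2500 file. When --total_steps changes, a small helper can pick the highest-numbered checkpoint instead of hard-coding the step. The sketch assumes the <output_model_path>-<step> naming, and latest_checkpoint and promote_latest are hypothetical helpers:

```python
import glob
import os
import re


def latest_checkpoint(output_model_path):
    """Return the '<output_model_path>-<step>' file with the largest step, or None."""
    pattern = re.compile(re.escape(output_model_path) + r"-(\d+)$")
    best, best_step = None, -1
    for path in glob.glob(glob.escape(output_model_path) + "-*"):
        m = pattern.match(path)
        # Compare steps as integers: as strings, "-10000" would sort before "-5000".
        if m and int(m.group(1)) > best_step:
            best, best_step = path, int(m.group(1))
    return best


def promote_latest(output_model_path):
    """Rename the newest checkpoint to output_model_path, like the mv command does by hand."""
    ckpt = latest_checkpoint(output_model_path)
    if ckpt is None:
        raise FileNotFoundError(f"no checkpoints found for {output_model_path}")
    os.replace(ckpt, output_model_path)
    return ckpt
```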

mv models/book_review_word_sentencepiece_model.bin-5000 models/book_review_word_sentencepiece_model.bin

python3 finetune/run_classifier.py --pretrained_model_path models/book_review_word_sentencepiece_model.bin \
                                   --spm_model_path models/cluecorpussmall_spm.model \
                                   --config_path models/bert/base_config.json \
                                   --train_path datasets/book_review/train.tsv \
                                   --dev_path datasets/book_review/dev.tsv \
                                   --test_path datasets/book_review/test.tsv \
                                   --epochs_num 3 --batch_size 32

python3 inference/run_classifier_infer.py --load_model_path models/finetuned_model.bin \
                                          --spm_model_path models/cluecorpussmall_spm.model \
                                          --config_path models/bert/base_config.json \
                                          --test_path datasets/book_review/test_nolabel.tsv \
                                          --prediction_path datasets/book_review/prediction.tsv \
                                          --labels_num 2

The text is tokenized by the SentencePiece model trained on the CLUECorpusSmall corpus (--spm_model_path models/cluecorpussmall_spm.model).

To use a character-based tokenizer, replace --spm_model_path models/cluecorpussmall_spm.model with --vocab_path models/google_zh_vocab.txt and --tokenizer char, and keep the other options the same as above. models/google_zh_vocab.txt can be used here because its Chinese entries are individual characters.
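With --tokenizer char, each character is looked up in models/google_zh_vocab.txt directly. The sketch below assumes that unknown characters map to [UNK] and that whitespace is dropped. This section states neither detail:

```python
def char_tokenize(text, vocab, unk="[UNK]"):
    """Emit one token per non-whitespace character; characters not in vocab become unk."""
    return [ch if ch in vocab else unk for ch in text if not ch.isspace()]
```

In this sketch, an English word is also split into single letters. By contrast, --tokenizer bert keeps such words together for WordPiece.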

More details can be found in Tokenization and vocabulary.
