In [1]:
from tokenizers.implementations import BertWordPieceTokenizer
from tokenizers.processors import TemplateProcessing
from tokenizers.pre_tokenizers import Whitespace
from transformers import PreTrainedTokenizerFast, BertConfig, BertForMaskedLM, DataCollatorForLanguageModeling
import torch.utils.data as Data
from transformers import Trainer, TrainingArguments

In [2]:
tokenizer = BertWordPieceTokenizer()  # 分词器
tokenizer.pre_tokenizer = Whitespace()


files = [f"wikitext-103-raw/wiki.{split}.raw" for split in ["test", "train", "valid"]]
tokenizer.train(files=files,
                vocab_size=30000,
                min_frequency=2,
                special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])

tokenizer.post_processor = TemplateProcessing(single="[CLS] $A [SEP]",
                                              pair="[CLS] $A [SEP] $B:1 [SEP]:1",
                                              special_tokens=[("[CLS]", tokenizer.token_to_id("[CLS]")),
                                                              ("[SEP]", tokenizer.token_to_id("[SEP]"))])

# Enable the padding
tokenizer.enable_padding(
    pad_id=tokenizer.token_to_id('[PAD]'),  # The id to be used when padding
    pad_token="[PAD]",  # The pad token to be used when padding
    pad_type_id=0)  # The type id to be used when padding

# Enable truncation
tokenizer.enable_truncation(
    max_length=512)  # 截断的最大长度






In [3]:
tokenizer.save("tokenizer-wikitext-103.json")

In [4]:
# 加载训练好的分词器
tokenizer_fast = PreTrainedTokenizerFast(tokenizer_file='tokenizer-wikitext-103.json')
tokenizer_fast

PreTrainedTokenizerFast(name_or_path='', vocab_size=30000, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={})

In [5]:
tokenizer_fast.add_special_tokens(special_tokens_dict={'eos_token': '[EOS]',
                                                       'mask_token': '[UNK]',
                                                       'pad_token': '[PAD]',
                                                       'cls_token': '[CLS]',
                                                       'sep_token': '[SEP]'})
tokenizer_fast.all_special_tokens

['[EOS]', '[SEP]', '[PAD]', '[CLS]', '[UNK]']

In [6]:
class LineByLineDataset(Data.Dataset):
    """读取与包装数据集"""

    def __init__(self, tokenizer, file_path):
        with open(file_path, encoding="utf-8") as f:
            data = f.readlines()
            print(len(data))
        data = [line.strip() for line in data if len(line) > 0 and not line.isspace()]
        print(len(data))
        batch_encoding = tokenizer(data, truncation=True, padding="max_length", return_tensors='pt', max_length=512)  # 内存消耗过大

        self.input_ids = batch_encoding['input_ids']
        self.attention_mask = batch_encoding['attention_mask']
        self.token_type_ids = batch_encoding['token_type_ids']

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i):
        return {'input_ids': self.input_ids[i],
                'attention_mask': self.attention_mask[i],
                'token_type_ids': self.token_type_ids[i]}


lbld = LineByLineDataset(tokenizer_fast, files[2])
print(len(lbld))

for i in lbld:
    print(i.keys())
    break

3760
2461
2461
dict_keys(['input_ids', 'attention_mask', 'token_type_ids'])


In [7]:
# Data collator used for language modeling
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer_fast, mlm=True, mlm_probability=0.15)
data_collator

DataCollatorForLanguageModeling(tokenizer=PreTrainedTokenizerFast(name_or_path='', vocab_size=30000, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '[EOS]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[UNK]'}), mlm=True, mlm_probability=0.15, pad_to_multiple_of=None, tf_experimental_compile=False, return_tensors='pt')

In [8]:
config = BertConfig(
    vocab_size=50000,
    hidden_size=768,
    num_hidden_layers=6,
    num_attention_heads=12,
    max_position_embeddings=512
)

model = BertForMaskedLM(config)
print('No of parameters: ', model.num_parameters())

No of parameters:  81965648


In [9]:
training_args = TrainingArguments(
    output_dir='preTrained_Model',
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size=16,
    save_strategy='epoch'
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=lbld,
    tokenizer=tokenizer_fast
)

trainer.train()

***** Running training *****
  Num examples = 2461
  Num Epochs = 5
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 770


Step,Training Loss
500,7.4131


Saving model checkpoint to preTrained_Model/checkpoint-154
Configuration saved in preTrained_Model/checkpoint-154/config.json
Model weights saved in preTrained_Model/checkpoint-154/pytorch_model.bin
tokenizer config file saved in preTrained_Model/checkpoint-154/tokenizer_config.json
Special tokens file saved in preTrained_Model/checkpoint-154/special_tokens_map.json
Saving model checkpoint to preTrained_Model/checkpoint-308
Configuration saved in preTrained_Model/checkpoint-308/config.json
Model weights saved in preTrained_Model/checkpoint-308/pytorch_model.bin
tokenizer config file saved in preTrained_Model/checkpoint-308/tokenizer_config.json
Special tokens file saved in preTrained_Model/checkpoint-308/special_tokens_map.json
Saving model checkpoint to preTrained_Model/checkpoint-462
Configuration saved in preTrained_Model/checkpoint-462/config.json
Model weights saved in preTrained_Model/checkpoint-462/pytorch_model.bin
tokenizer config file saved in preTrained_Model/checkpoint-462/

TrainOutput(global_step=770, training_loss=7.227665096134334, metrics={'train_runtime': 187.9602, 'train_samples_per_second': 65.466, 'train_steps_per_second': 4.097, 'total_flos': 1631901312860160.0, 'train_loss': 7.227665096134334, 'epoch': 5.0})

In [10]:
model.save_pretrained('save_pretrained_model/')

Configuration saved in save_pretrained_model/config.json
Model weights saved in save_pretrained_model/pytorch_model.bin
