In [3]:
from transformers import (RobertaForMaskedLM, RobertaTokenizer, DataCollatorForLanguageModeling, Trainer,
                          TrainingArguments)
from datasets import Dataset

In [4]:
# Checkpoint identifier; both the model and the tokenizer are loaded from it.
model_name = 'roberta-base'

# Start from the already-trained pretrained model (masked-LM class).
model = RobertaForMaskedLM.from_pretrained(model_name)
# Reuse the existing tokenizer that belongs to the same checkpoint.
tokenizer = RobertaTokenizer.from_pretrained(model_name)

Downloading:   0%|          | 0.00/481 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/478M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

In [5]:
# Raw split files, listed in the order: test, train, valid.
files = [
    f"wikitext-103-raw/wiki.{split}.raw"
    for split in ["test", "train", "valid"]
]

# Only the train and valid files (the last two entries) are loaded here,
# each as a dataset built from a plain-text file.
dataset_train, dataset_valid = (Dataset.from_text(path) for path in files[1:])

# Show a summary of each dataset, one per line.
print(dataset_train, dataset_valid, sep="\n")

Using custom data configuration default-5039f54a81ca91e6


Downloading and preparing dataset text/default to /root/.cache/huggingface/datasets/text/default-5039f54a81ca91e6/0.0.0...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-dcd9eeac650775a1


Dataset text downloaded and prepared to /root/.cache/huggingface/datasets/text/default-5039f54a81ca91e6/0.0.0. Subsequent calls will reuse this data.
Downloading and preparing dataset text/default to /root/.cache/huggingface/datasets/text/default-dcd9eeac650775a1/0.0.0...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Dataset text downloaded and prepared to /root/.cache/huggingface/datasets/text/default-dcd9eeac650775a1/0.0.0. Subsequent calls will reuse this data.
Dataset({
    features: ['text'],
    num_rows: 1801350
})
Dataset({
    features: ['text'],
    num_rows: 3760
})


In [7]:
def filter_func(data):
    """Tell whether a row carries real text (used to drop blank lines).

    Args:
        data: A mapping whose 'text' entry is a string.

    Returns:
        bool: False when the string is empty or consists only of whitespace,
        True otherwise.
    """
    line = data['text']
    if len(line) == 0:
        # Empty string: nothing worth keeping.
        return False
    # Non-empty: keep it unless every character is whitespace.
    return not line.isspace()


def map_func(data):
    """Tokenize the 'text' entry of `data` into fixed-length model inputs.

    The notebook-level `tokenizer` is called with truncation enabled and
    padding="max_length" at max_length=512.

    Args:
        data: A mapping with a 'text' entry (in this notebook the function is
            applied with batched=True, so it receives a batch of rows).

    Returns:
        dict: The 'input_ids' and 'attention_mask' fields of the tokenizer output.
    """
    encoded = tokenizer(data['text'], max_length=512, padding="max_length", truncation=True)
    # model_input_names of the roberta-base model are ['input_ids', 'attention_mask'],
    # so only those two fields are returned.
    return {field: encoded[field] for field in ('input_ids', 'attention_mask')}


def _filter_and_tokenize(dataset):
    """Apply filter_func, then map_func; return (filtered, tokenized) datasets."""
    kept = dataset.filter(filter_func)
    # Batched map: 1000 rows are processed per call.
    tokenized = kept.map(map_func, batched=True, batch_size=1000)
    return kept, tokenized


# Same two-step pipeline for both splits: train first, then valid.
dataset_train_filter, dataset_train_map = _filter_and_tokenize(dataset_train)
dataset_valid_filter, dataset_valid_map = _filter_and_tokenize(dataset_valid)

  0%|          | 0/1802 [00:00<?, ?ba/s]

  0%|          | 0/1166 [00:00<?, ?ba/s]

  0%|          | 0/4 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

In [9]:
# Plays the role of collate_fn in torch.utils.data.DataLoader (it can be overridden;
# see K_demo/way_of_training/pytorch_transformer.ipynb).
# Data collator used for language modeling
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,              # masked language modeling mode
    mlm_probability=0.15,  # masking probability handed to the collator
)
# Bare expression so the notebook displays the collator's repr.
data_collator

DataCollatorForLanguageModeling(tokenizer=PreTrainedTokenizer(name_or_path='roberta-base', vocab_size=50265, model_max_len=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'sep_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_token': AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'cls_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True)}), mlm=True, mlm_probability=0.15, pad_to_multiple_of=None, tf_experimental_compile=False, return_tensors='pt')

In [None]:
training_args = TrainingArguments(
    # Where checkpoints are written; overwriting the directory is allowed.
    output_dir='output_dir',
    overwrite_output_dir=True,

    # Run length: a fixed number of optimization steps
    # (the trainer log notes that max_steps overrides num_train_epochs).
    max_steps=3000,

    # Per-device batch sizes for training and evaluation.
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,

    # Evaluate every 200 steps; save a checkpoint every 1000 steps.
    evaluation_strategy='steps',
    eval_steps=200,
    save_steps=1000,

    # At the end, reload the best checkpoint as ranked by eval_loss.
    load_best_model_at_end=True,
    metric_for_best_model='eval_loss',
)

# Continue training the pretrained model.
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    # Tokenized splits produced by the filter/map cell.
    train_dataset=dataset_train_map,
    eval_dataset=dataset_valid_map,
    # Batches are assembled by the language-modeling collator.
    data_collator=data_collator,
)

# Last expression of the cell, so the notebook shows the training result.
trainer.train()

max_steps is given, it will override any value given in num_train_epochs
The following columns in the training set  don't have a corresponding argument in `RobertaForMaskedLM.forward` and have been ignored: text. If text are not expected by `RobertaForMaskedLM.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 1165029
  Num Epochs = 1
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 3000


Step,Training Loss,Validation Loss
200,No log,1.589433
400,No log,1.575636
600,1.657900,1.582396
800,1.657900,1.572175
1000,1.620300,1.54285
1200,1.620300,1.534524
1400,1.620300,1.513041
1600,1.593500,1.521249
1800,1.593500,1.525004
2000,1.564600,1.485153


The following columns in the evaluation set  don't have a corresponding argument in `RobertaForMaskedLM.forward` and have been ignored: text. If text are not expected by `RobertaForMaskedLM.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 2461
  Batch size = 16
The following columns in the evaluation set  don't have a corresponding argument in `RobertaForMaskedLM.forward` and have been ignored: text. If text are not expected by `RobertaForMaskedLM.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 2461
  Batch size = 16
The following columns in the evaluation set  don't have a corresponding argument in `RobertaForMaskedLM.forward` and have been ignored: text. If text are not expected by `RobertaForMaskedLM.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 2461
  Batch size = 16
The following columns in the evaluation set  don't have a corresponding argu