# 掩码语言模型训练实例

#### Step1. 导入相关包

In [1]:
from datasets import load_dataset,Dataset
from transformers import AutoTokenizer, AutoModelForMaskedLM, DataCollatorForLanguageModeling, TrainingArguments, Trainer

2023-10-01 17:41:45.299850: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX, in other operations, rebuild TensorFlow with the appropriate compiler flags.


#### Step2. 加载数据集

In [None]:
# Load the pre-filtered Chinese Wikipedia corpus that was saved earlier with
# `Dataset.save_to_disk`; `process_func` below reads its "completion" column.
# NOTE(review): confirm the on-disk directory is ./wiki_cn_filtered/.
# Alternative (English public corpus downloaded from the Hub):
# dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
ds = Dataset.load_from_disk("./wiki_cn_filtered/")

In [None]:
# Display the loaded dataset object (its repr lists the features and row count).
ds

In [None]:
# Inspect the first example: a dict mapping each column name to its value.
ds[0]

#### Step3. 数据处理

In [None]:
# Chinese MacBERT tokenizer; the Hub namespace is "hfl" (lowercase L).
# NOTE(review): if a local directory was intended instead, point this at it.
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-macbert-base")

def process_func(examples):
    """Tokenize a batch of raw texts for masked-LM training.

    Args:
        examples: a batch passed by ``Dataset.map(batched=True)``; its
            ``"completion"`` entry holds the list of texts to encode.

    Returns:
        The tokenizer's encoding of the texts, each sequence truncated to at
        most 384 tokens (special tokens included). No padding is applied here;
        the data collator pads per batch.
    """
    texts = examples["completion"]
    encoded = tokenizer(texts, truncation=True, max_length=384)
    return encoded

# Tokenize in batches and drop every original column, so only the tokenizer
# outputs remain for the MLM data collator to consume.
tokenized_ds = ds.map(
    process_func,
    batched=True,
    remove_columns=ds.column_names,
)
tokenized_ds

In [None]:
from torch.utils.data import DataLoader

# Collator that pads each batch and randomly selects 15% of tokens for
# masking, producing masked `input_ids` plus `labels` for the MLM loss.
mlm_collator = DataCollatorForLanguageModeling(
    tokenizer, mlm=True, mlm_probability=0.15
)
dl = DataLoader(tokenized_ds, batch_size=2, collate_fn=mlm_collator)

In [None]:
# Pull one collated batch to eyeball the masking; plain `iter` is enough —
# wrapping the loader in `enumerate` would only add a (0, batch) tuple.
next(iter(dl))

In [None]:
# The mask token string and its vocabulary id (the id the collator substitutes
# into masked positions).
mask_info = (tokenizer.mask_token, tokenizer.mask_token_id)
mask_info

#### Step4. 创建模型

In [None]:
# MacBERT with a masked-LM head; must match the tokenizer's checkpoint.
# The Hub namespace is "hfl" (lowercase L).
model = AutoModelForMaskedLM.from_pretrained("hfl/chinese-macbert-base")

#### Step5. 配置训练参数

In [None]:
# One training epoch, logging every 10 steps; outputs and checkpoints are
# written under ./masked_lm.
args = TrainingArguments(
    output_dir="./masked_lm",
    num_train_epochs=1,
    per_device_train_batch_size=32,
    logging_steps=10,
)

#### Step6. 创建训练器

In [None]:
# Wire model, arguments and data together; every training batch goes through
# the same 15% MLM masking collator used for the preview DataLoader.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_ds,
    data_collator=DataCollatorForLanguageModeling(
        tokenizer, mlm=True, mlm_probability=0.15
    ),
)

#### Step7. 模型训练

In [None]:
# Run training; the returned TrainOutput (global step, training loss, metrics)
# is displayed as the cell output.
trainer.train()

#### Step8. 模型推理

In [None]:
from transformers import pipeline

# Fill-mask pipeline around the fine-tuned model. Reuse the device the model
# already sits on (Trainer typically moves it to the GPU when one exists)
# instead of hard-coding device=0, which fails on CPU-only machines.
pipe = pipeline(
    "fill-mask", model=model, tokenizer=tokenizer, device=model.device
)
# With two [MASK] tokens the pipeline returns one candidate list per mask.
pipe("西安交通[MASK][MASK]博物馆(Xi'an Jiaotong University Museum)是一座位于西安交通大学的博物馆")

In [None]:
# Second probe: predict the news category hidden behind the two [MASK] tokens.
pipe("下面是一则[MASK][MASK]新闻，小编报道，近日，游戏产业发展得非常好!")