# 文本填空

In [1]:
from pathlib import Path

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

from datasets import load_dataset
from transformers import BertModel, BertTokenizer

## 1. 定义数据集

In [6]:
class MyDataset(Dataset):
    """Raw review texts from the ChnSentiCorp corpus, used as fill-in-the-blank samples.

    Each item is the review string only; the sentiment label is not returned.

    Args:
        split: Split name forwarded to ``load_dataset`` (e.g. ``'test'``).
        min_length: Keep only texts with strictly more than this many characters
            (default 30). Presumably chosen so the fixed position masked by the
            batching function falls on real text rather than padding.
    """

    def __init__(self, split: str, min_length: int = 30):
        super().__init__()

        dataset = load_dataset('seamew/ChnSentiCorp', split=split)

        # Drop texts of min_length characters or fewer.
        self.dataset = dataset.filter(lambda row: len(row['text']) > min_length)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        # Index the row first: dataset['text'] returns the whole text column as a list
        # in many `datasets` versions, which made every single lookup O(n).
        return self.dataset[item]['text']


# Evaluation split, filtered to texts longer than 30 characters.
dataset = MyDataset('test')
print(len(dataset))

Using custom data configuration default
Reusing dataset chn_senti_corp (C:\Users\Jejune\.cache\huggingface\datasets\seamew___chn_senti_corp\default\0.0.0\1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85)
Loading cached processed dataset at C:\Users\Jejune\.cache\huggingface\datasets\seamew___chn_senti_corp\default\0.0.0\1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85\cache-6f2a18d626566fa2.arrow


1145


## 2. 定义tokenizer

In [7]:
# Tokenizer that matches the bert-base-chinese checkpoint loaded in section 4.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

tokenizer

PreTrainedTokenizer(name_or_path='bert-base-chinese', vocab_size=21128, model_max_len=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

## 3. 定义批处理函数

In [9]:
# The [MASK] placeholder; collate_fn writes its vocabulary id into the masked slot.
mask_token = tokenizer.mask_token
mask_token

'[MASK]'

In [10]:
def collate_fn(data, *, mask_pos=15, max_length=30, tok=None):
    """Tokenize a batch of texts and mask one fixed position for fill-in-the-blank training.

    Args:
        data: Iterable of raw text strings (a DataLoader batch, or a whole dataset).
        mask_pos: Token position replaced by the mask token; its original id becomes
            the label. Must satisfy ``0 <= mask_pos < max_length``.
        max_length: Every sequence is truncated / padded to exactly this many tokens.
        tok: Tokenizer to use; defaults to the module-level ``tokenizer``.

    Returns:
        ``(input_ids, attention_mask, token_type_ids, labels)``. The first three have
        shape ``(batch, max_length)``; ``labels`` has shape ``(batch,)``.

    Raises:
        ValueError: if ``mask_pos`` is outside ``[0, max_length)``.
    """
    if not 0 <= mask_pos < max_length:
        raise ValueError(f'mask_pos must be in [0, {max_length}), got {mask_pos}')
    tok = tokenizer if tok is None else tok

    # list() so any iterable works, e.g. when called on the whole dataset object.
    texts = list(data)
    encoded = tok.batch_encode_plus(
        batch_text_or_text_pairs=texts,
        truncation=True,
        padding='max_length',
        max_length=max_length,
        return_tensors='pt',
    )
    # Encoded token ids.
    input_ids = encoded['input_ids']
    # 0 at padding positions, 1 elsewhere.
    attention_mask = encoded['attention_mask']
    # 0 for the first sentence and special tokens, 1 for a second sentence (none here).
    token_type_ids = encoded['token_type_ids']

    # clone() before masking: the column slice is a view into input_ids.
    labels = input_ids[:, mask_pos].clone()
    # NOTE(review): if a text tokenizes to fewer than mask_pos + 1 tokens (possible when
    # Latin words collapse into single word pieces), this label may be a special/pad id.
    input_ids[:, mask_pos] = tok.mask_token_id
    return input_ids, attention_mask, token_type_ids, labels


# Encode and mask the whole filtered split once, and cache the tensors on disk.
data_dir = Path('data')
# torch.save does not create missing directories.
data_dir.mkdir(parents=True, exist_ok=True)

# Distinct names so the full-split tensors do not shadow the per-batch variables below.
full_split_tensors = dict(zip(
    ('input_ids', 'attention_mask', 'token_type_ids', 'labels'),
    collate_fn(dataset),
))
for tensor_name, tensor in full_split_tensors.items():
    torch.save(tensor, data_dir / f'{tensor_name}.pt')

# Shuffled mini-batches of 16 texts; collate_fn tokenizes and masks each batch.
# drop_last discards the final partial batch so every batch has exactly 16 rows.
loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn,
)


In [11]:
# Take a single batch to inspect shapes and to use in the smoke tests below.
input_ids, attention_mask, token_type_ids, labels = next(iter(loader))

print(len(loader))
input_ids.shape, attention_mask.shape, token_type_ids.shape, labels

71


(torch.Size([16, 30]),
 torch.Size([16, 30]),
 torch.Size([16, 30]),
 tensor([4638, 1044, 3315, 3221, 4281, 4263, 8303, 6432, 7231,  754, 5326, 3141,
         3123,  712, 4495, 1963]))

## 4. 定义 BERT 预训练模型

In [12]:
pretrained = BertModel.from_pretrained('bert-base-chinese')

# Frozen: the encoder is not trained, so no gradients are needed.
pretrained.requires_grad_(False)

# Smoke test on the sample batch; the output above shows [16, 30, 768].
with torch.no_grad():
    out = pretrained(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
    )
out.last_hidden_state.shape

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([16, 30, 768])

## 5. 下游任务：预测 [MASK] 位置的字

In [13]:
class Net(nn.Module):
    """Linear head that predicts the token at one masked position from frozen BERT features.

    Only ``decoder.weight`` and the shared ``bias`` are trainable. The encoder is the
    module-level ``pretrained`` model, called under ``torch.no_grad()`` in ``forward``.

    Args:
        hidden_size: Width of the encoder's hidden states (768 for bert-base-chinese).
        vocab_size: Number of output classes; defaults to ``tokenizer.vocab_size``.
        mask_pos: Sequence position whose hidden state is decoded. It must match the
            position masked by ``collate_fn`` (15 by default in both).
    """

    def __init__(self, hidden_size=768, vocab_size=None, mask_pos=15):
        super().__init__()
        if vocab_size is None:
            vocab_size = tokenizer.vocab_size
        self.mask_pos = mask_pos
        self.decoder = nn.Linear(in_features=hidden_size, out_features=vocab_size, bias=False)
        # Zero-initialised bias registered on self and reused as the decoder's bias
        # (the same decoder/bias split as BERT's `cls.predictions` head). parameters()
        # de-duplicates it, so the optimizer sees it only once.
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        self.decoder.bias = self.bias

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return logits of shape ``(batch, vocab_size)`` for the masked position."""
        with torch.no_grad():
            out = pretrained(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )
        return self.decoder(out.last_hidden_state[:, self.mask_pos])

model = Net()

# Smoke test: one forward pass on the sample batch should yield (batch_size, vocab_size).
logits = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids,
)
logits.shape

torch.Size([16, 21128])

In [15]:
# Train only the Net head. torch.optim.AdamW replaces transformers.AdamW, which is
# deprecated (and removed in newer transformers releases). weight_decay=0.0 keeps that
# optimizer's default; torch's own default would be 0.01.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.0)
criterion = nn.CrossEntropyLoss()

LOG_EVERY = 5
# Step cap; one pass over `loader` (71 batches above) is shorter, so in practice
# the loop ends with the epoch.
MAX_STEP = 300

model.train()
for i, (input_ids, attention_mask, token_type_ids,
        labels) in enumerate(loader):
    logits = model(input_ids=input_ids,
                   attention_mask=attention_mask,
                   token_type_ids=token_type_ids)

    loss = criterion(logits, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i % LOG_EVERY == 0:
        # Separate name so the logits are not overwritten by the predictions.
        preds = logits.argmax(dim=1)
        accuracy = (preds == labels).sum().item() / len(labels)

        print(i, loss.item(), accuracy)

    if i == MAX_STEP:
        break

0 8.925697326660156 0.125
5 8.994549751281738 0.125
10 9.03781795501709 0.0
15 8.448066711425781 0.0625
20 8.581857681274414 0.0625
25 7.799210071563721 0.1875
30 7.254517555236816 0.3125
35 7.724375247955322 0.125
40 8.473230361938477 0.0625
45 7.4295196533203125 0.1875
50 8.093338012695312 0.125
55 7.525974750518799 0.0625
60 6.866983413696289 0.125
65 7.203446865081787 0.125
70 6.602130889892578 0.125
