In [1]:
import copy
import torch
import pickle

from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)

Device: cuda


# 模型构建

In [3]:
class CorrectionNetwork(nn.Module):

    def __init__(self):
        super(CorrectionNetwork, self).__init__()
        # BERT分词器
        self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
        # BERT
        self.bert = AutoModel.from_pretrained("hfl/chinese-roberta-wwm-ext")
        # BERT的word embedding，本质就是个nn.Embedding
        self.word_embedding_table = self.bert.get_input_embeddings()
        # 预测层。hidden_size是词向量的大小，len(self.tokenizer)是词典大小
        self.dense_layer = nn.Linear(self.bert.config.hidden_size, len(self.tokenizer))

    def forward(self, inputs, word_embeddings, detect_hidden_states):
        """
        Correction Network的前向传递
        :param inputs: inputs为tokenizer对中文文本的分词结果，
                       里面包含了token对一个的index，attention_mask等
        :param word_embeddings: 使用BERT的word_embedding对token进行embedding后的结果
        :param detect_hidden_states: Detection Network输出hidden state
        :return: Correction Network对个token的预测结果。
        """
        # 1. 使用bert进行前向传递
        bert_outputs = self.bert(token_type_ids=inputs['token_type_ids'],
                                 attention_mask=inputs['attention_mask'],
                                 inputs_embeds=word_embeddings)
        # 2. 将bert的hidden_state和Detection Network的hidden state进行融合。
        hidden_states = bert_outputs['last_hidden_state'] + detect_hidden_states
        # 3. 最终使用全连接层进行token预测
        return self.dense_layer(hidden_states)

    def get_inputs_and_word_embeddings(self, sequences, max_length=128):
        """
        对中文序列进行分词和word embeddings处理
        :param sequences: 中文文本序列。例如: ["鸡你太美", "哎呦，你干嘛！"]
        :param max_length: 文本的最大长度，不足则进行填充，超出进行裁剪。
        :return: tokenizer的输出和word embeddings.
        """
        inputs = self.tokenizer(sequences, padding='max_length', max_length=max_length, return_tensors='pt',
                                truncation=True).to(device)
        # 使用BERT的work embeddings对token进行embedding，这里得到的embedding并不包含position embedding和segment embedding
        word_embeddings = self.word_embedding_table(inputs['input_ids'])
        return inputs, word_embeddings

In [4]:
class DetectionNetwork(nn.Module):

    def __init__(self, position_embeddings, transformer_blocks, hidden_size):
        """
        :param position_embeddings: bert的position_embeddings，本质是一个nn.Embedding
        :param transformer: BERT的前两层transformer_block，其是一个ModuleList对象
        """
        super(DetectionNetwork, self).__init__()
        self.position_embeddings = position_embeddings
        self.transformer_blocks = transformer_blocks

        # 定义最后的预测层，预测哪个token是错误的
        self.dense_layer = nn.Sequential(
            nn.Linear(hidden_size, 1),
            nn.Sigmoid()
        )

    def forward(self, word_embeddings):
        # 获取token序列的长度，这里为128
        sequence_length = word_embeddings.size(1)
        # 生成position embedding
        position_embeddings = self.position_embeddings(torch.LongTensor(range(sequence_length)).to(device))
        # 融合work_embedding和position_embedding
        x = word_embeddings + position_embeddings
        # 将x一层一层的使用transformer encoder进行向后传递
        for transformer_layer in self.transformer_blocks:
            x = transformer_layer(x)[0]

        # 最终返回Detection Network输出的hidden states和预测结果
        hidden_states = x
        return hidden_states, self.dense_layer(hidden_states)

In [5]:
class MDCSpellModel(nn.Module):

    def __init__(self):
        super(MDCSpellModel, self).__init__()
        # 构造Correction Network
        self.correction_network = CorrectionNetwork()
        self._init_correction_dense_layer()

        # 构造Detection Network
        # position embedding使用BERT的
        position_embeddings = self.correction_network.bert.embeddings.position_embeddings
        # 作者在论文中提到的，Detection Network的Transformer使用BERT的权重
        # 所以我这里直接克隆BERT的前两层Transformer来完成这个动作
        transformer = copy.deepcopy(self.correction_network.bert.encoder.layer[:2])
        # 提取BERT的词向量大小
        hidden_size = self.correction_network.bert.config.hidden_size

        # 构造Detection Network
        self.detection_network = DetectionNetwork(position_embeddings, transformer, hidden_size)

    def forward(self, sequences, max_length=128):
        # 先获取word embedding，Correction Network和Detection Network都要用
        inputs, word_embeddings = self.correction_network.get_inputs_and_word_embeddings(sequences, max_length)
        # Detection Network进行前向传递，获取输出的Hidden State和预测结果
        hidden_states, detection_outputs = self.detection_network(word_embeddings)
        # Correction Network进行前向传递，获取其预测结果
        correction_outputs = self.correction_network(inputs, word_embeddings, hidden_states)
        # 返回Correction Network 和 Detection Network 的预测结果。
        # 在计算损失时`[PAD]`token不需要参与计算，所以这里将`[PAD]`部分全都变为0
        return correction_outputs, detection_outputs.squeeze(2) * inputs['attention_mask']

    def _init_correction_dense_layer(self):
        """
        原论文中提到，使用Word Embedding的weight来对Correction Network进行初始化
        """
        self.correction_network.dense_layer.weight.data = self.correction_network.word_embedding_table.weight.data

定义好模型后，我们来简单的尝试一下：

In [6]:
model = MDCSpellModel().to(device)
correction_outputs, detection_outputs = model(["鸡你太美", "哎呦，你干嘛！"])
print("correction_outputs shape:", correction_outputs.size())
print("detection_outputs shape:", detection_outputs.size())

Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


correction_outputs shape: torch.Size([2, 128, 21128])
detection_outputs shape: torch.Size([2, 128])


# 损失函数

In [7]:
class MDCSpellLoss(nn.Module):

    def __init__(self, coefficient=0.85):
        super(MDCSpellLoss, self).__init__()
        # 定义Correction Network的Loss函数
        self.correction_criterion = nn.CrossEntropyLoss(ignore_index=0)
        # 定义Detection Network的Loss函数，因为是二分类，所以用Binary Cross Entropy
        self.detection_criterion = nn.BCELoss()
        # 权重系数
        self.coefficient = coefficient

    def forward(self, correction_outputs, correction_targets, detection_outputs, detection_targets):
        """
        :param correction_outputs: Correction Network的输出，Shape为(batch_size, sequence_length, hidden_size)
        :param correction_targets: Correction Network的标签，Shape为(batch_size, sequence_length)
        :param detection_outputs: Detection Network的输出，Shape为(batch_size, sequence_length)
        :param detection_targets: Detection Network的标签，Shape为(batch_size, sequence_length)
        :return:
        """
        # 计算Correction Network的loss，因为Shape维度为3，所以要把batch_size和sequence_length进行合并才能计算
        correction_loss = self.correction_criterion(correction_outputs.view(-1, correction_outputs.size(2)),
                                                    correction_targets.view(-1))
        # 计算Detection Network的loss
        detection_loss = self.detection_criterion(detection_outputs, detection_targets)
        # 对两个loss进行加权平均
        return self.coefficient * correction_loss + (1 - self.coefficient) * detection_loss

# 模型训练

作者首先使用了Wang271K数据集进行对模型进行训练，然后又使用SIGHAN训练集对模型进行fine-tune。这里我就不进行fine-tune了，直接进行训练。

## 构造Dataset

In [8]:
class CSCDataset(Dataset):

    def __init__(self):
        super(CSCDataset, self).__init__()
        with open("data/trainall.times2.pkl", mode='br') as f:
            train_data = pickle.load(f)

        self.train_data = train_data

    def __getitem__(self, index):
        src = self.train_data[index]['src']
        tgt = self.train_data[index]['tgt']
        return src, tgt

    def __len__(self):
        return len(self.train_data)

In [9]:
train_data = CSCDataset()

In [10]:
train_data.__getitem__(0)

('纽约早盘作为基准的低硫轻油，五月份交割价攀升一点三四美元，来到每桶二十八点二五美元，而上周五曾下挫一豪元以上。',
 '纽约早盘作为基准的低硫轻油，五月份交割价攀升一点三四美元，来到每桶二十八点二五美元，而上周五曾下挫一美元以上。')

# 构造Dataloader

In [11]:
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

In [12]:
def collate_fn(batch):
    src, tgt = zip(*batch)
    src, tgt = list(src), list(tgt)

    src_tokens = tokenizer(src, padding='max_length', max_length=128, return_tensors='pt', truncation=True)['input_ids']
    tgt_tokens = tokenizer(tgt, padding='max_length', max_length=128, return_tensors='pt', truncation=True)['input_ids']

    correction_targets = tgt_tokens
    detection_targets = (src_tokens != tgt_tokens).float()
    return src, correction_targets, detection_targets

In [13]:
train_loader = DataLoader(train_data, batch_size=32, collate_fn=collate_fn, shuffle=True)

## 训练

In [14]:
model = model.to(device)
model = model.train()

In [15]:
criterion = MDCSpellLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

In [16]:
total_loss = 0.


for sequences, correction_targets, detection_targets in train_loader:
    correction_targets, detection_targets = correction_targets.to(device), detection_targets.to(device)
    correction_outputs, detection_outputs = model(sequences)
    loss = criterion(correction_outputs, correction_targets, detection_outputs, detection_targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    total_loss += loss.detach().item()

    print(loss)

tensor(3.8491, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.6342, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.8371, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.2116, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.9505, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.7459, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.7200, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.6434, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.6498, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.6412, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5951, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.6212, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5708, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.4601, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5493, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5711, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.4655, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.4438, device='cuda:0', grad_fn=<AddBack

KeyboardInterrupt: 