# Environment

In [2]:
try:
    import transformers
except:
    !pip install transformers

In [3]:
import time
import os
import pickle

import transformers
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm

In [4]:
transformers.__version__

'4.21.2'

# Global Config

In [5]:
# Token length every sentence is padded/truncated to by the tokenizer calls below.
max_length = 128
# Samples per batch for the DataLoaders.
batch_size = 32
# Upper bound of the epoch range used by the main training loop.
epochs = 10

# A metrics line is printed/logged every `log_after_step` optimisation steps.
log_after_step = 20

# Checkpoint directory. makedirs also creates missing parent directories and is a
# no-op when the directory already exists (os.mkdir would fail on missing parents).
model_path = './drive/MyDrive/models/'
os.makedirs(model_path, exist_ok=True)
model_path = model_path + 'csc-model.pt'

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
# Smoke-test switch. When enabled, the settings below replace the global config so the
# whole pipeline can be exercised on a handful of samples.
test_mode = False
if test_mode:
    # Save next to the notebook instead of the configured checkpoint directory.
    model_path = 'csc-model.pt'
    # Many epochs over only `data_length` samples (see CSCDataset.__len__).
    epochs, data_length = 1000, 10
    # One sample per batch and a log line after every step.
    batch_size = log_after_step = 1

In [7]:
# Download the shared Google Drive folder into ./data using the gdown CLI; the training
# pickle is later read from data/. gdown must be installed and on PATH — the recorded
# output below shows the command was not found on the machine that last ran this cell.
!gdown '1dC09i57lobL91lEbpebDuUBS0fGz-LAk' --folder --output data

'gdown' 不是内部或外部命令，也不是可运行的程序
或批处理文件。


In [8]:
# Download the SIGHAN archive with gdown and unpack it quietly; later cells read from sighan/.
# NOTE(review): `unzip` and `/dev/null` assume a Unix-like shell — the recorded output
# below shows both commands failing on the machine that last ran this cell.
!gdown '1CMNG1F_MmdrC1VGDpY6S0LCdLZcUuhhC' --output sighan.zip
!unzip sighan.zip > /dev/null

'gdown' 不是内部或外部命令，也不是可运行的程序
或批处理文件。
拒绝访问。


# Data

In [9]:
class CSCDataset(Dataset):
    """Map-style dataset over a sequence of records that each hold a 'src' and a 'tgt' entry."""

    def __init__(self, training_data):
        """Keep a reference to ``training_data``; records are looked up by index on demand."""
        super().__init__()
        self.training_data = training_data

    def __getitem__(self, index):
        """Return the ``(src, tgt)`` pair stored in the record at ``index``."""
        record = self.training_data[index]
        return record['src'], record['tgt']

    def __len__(self):
        """Return the number of samples.

        While the notebook-level ``test_mode`` flag is on, the global ``data_length`` is
        reported instead of the real size.
        """
        return data_length if test_mode else len(self.training_data)

In [10]:
# Read the pickled training records and wrap them in a CSCDataset in one step.
# NOTE: pickle.load can execute arbitrary code — only use it on trusted files.
with open("data/trainall.times2.pkl", mode='br') as pkl_file:
    train_data = CSCDataset(pickle.load(pkl_file))

In [11]:
train_data.__getitem__(0)

('纽约早盘作为基准的低硫轻油，五月份交割价攀升一点三四美元，来到每桶二十八点二五美元，而上周五曾下挫一豪元以上。',
 '纽约早盘作为基准的低硫轻油，五月份交割价攀升一点三四美元，来到每桶二十八点二五美元，而上周五曾下挫一美元以上。')

## Dataloader

In [12]:
# Shared tokenizer used for batching, prediction and evaluation; it is loaded from the
# same checkpoint name as the encoder inside CSCModel.
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

In [13]:
def collate_fn(batch):
    """Convert a list of ``(src, tgt)`` text pairs into model inputs and detection labels.

    Both sides are tokenized with the notebook-level ``tokenizer`` and padded/truncated to
    ``max_length`` so their token-id tensors can be compared position by position.

    Returns:
        ``(src_encoding, labels)`` where ``labels`` is a float tensor holding 1.0 wherever
        the source token id differs from the target token id and 0.0 elsewhere.
    """
    src_texts, tgt_texts = map(list, zip(*batch))

    # Identical tokenizer settings for both sides.
    tokenize_kwargs = dict(padding='max_length', max_length=max_length, return_tensors='pt', truncation=True)
    src_encoding = tokenizer(src_texts, **tokenize_kwargs)
    tgt_encoding = tokenizer(tgt_texts, **tokenize_kwargs)

    labels = (src_encoding['input_ids'] != tgt_encoding['input_ids']).float()
    return src_encoding, labels

In [14]:
# Shuffled training batches built by collate_fn; drop_last=True discards a final incomplete batch.
train_loader = DataLoader(train_data, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, drop_last=True)

In [15]:
# Pull a single batch; `inputs` is reused by the model smoke test further down.
inputs, targets = next(iter(train_loader))

# Model

In [16]:
class LayerNorm(nn.Module):
    """
    Layer normalization over the last (feature) dimension with a learnable gain and bias.

    Every feature vector is normalized on its own, using the mean and standard deviation
    of that vector, so the statistics do not depend on the other items in the batch.
    This is the "Norm" part of "Add & Norm" in the Transformer diagram.

    Note: the result is not numerically identical to ``torch.nn.LayerNorm`` — this class
    uses ``Tensor.std`` with its default settings and adds ``eps`` to the standard
    deviation.
    """

    def __init__(self, features, eps=1e-6):
        """
        features: int, size of the last dimension of the input (e.g. the width of one
                  token vector); usually the same as d_model.
        eps: small constant added to the standard deviation to avoid division by zero.
        """
        super().__init__()
        # Learnable gain (gamma) and bias (beta); nn.Parameter registers them with the
        # module so they are trained together with the rest of the model.
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        """
        x: tensor whose last dimension has size ``features``, e.g. (1, 7, 128) for one
           sentence of 7 tokens with 128-dimensional vectors. The output has the same
           shape as the input.
        """
        # Statistics are taken along the last dimension and keep it as size 1,
        # e.g. (1, 7, 1), so they broadcast against x.
        centered = x - x.mean(-1, keepdim=True)
        spread = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centered / spread + self.b_2

In [17]:
class CSCModel(nn.Module):
    """Token-level spelling-error detector built on a pretrained Chinese RoBERTa encoder.

    For every token position the model outputs a value in [0, 1] (sigmoid) that is
    trained against a 0/1 label marking tokens that differ between source and target.
    Padding positions are forced to 0 by multiplying with the attention mask.
    """

    def __init__(self):
        super(CSCModel, self).__init__()

        self.bert = AutoModel.from_pretrained("hfl/chinese-roberta-wwm-ext")
        self.word_embeddings = self.bert.get_input_embeddings()

        # Extra encoder blocks applied after the fusion layer.
        # NOTE(review): slicing the ModuleList reuses the very same layer objects, so
        # these blocks share weights with the first two BERT layers instead of being
        # independent copies initialised from them — confirm that sharing is intended.
        self.transformer_blocks = self.bert.encoder.layer[:2]

        # Projects [contextual state ; input embedding ; CLS state] back to 768 features.
        self.fusion_layer = nn.Sequential(
            nn.Linear(768 * 3, 768),
            nn.Sigmoid()  # TODO: which activation function works best here?
        )

        # One sigmoid score per token.
        self.output_layer = nn.Sequential(
            nn.Linear(768, 1),
            nn.Sigmoid()
        )

        self.norm = LayerNorm(768)

    def forward(self, inputs):
        """Return per-token error scores of shape (batch, sequence_length).

        inputs: tokenizer output providing at least 'input_ids' and 'attention_mask';
                it is unpacked as keyword arguments into the BERT encoder.
        """
        attention_mask = inputs['attention_mask']
        token_num = inputs['input_ids'].size(1)

        outputs = self.bert(**inputs)
        word_embeddings = self.word_embeddings(inputs['input_ids'])
        # Broadcast the first ([CLS]) hidden state to every token position.
        cls_outputs = outputs.last_hidden_state[:, 0:1, :].repeat(1, token_num, 1)
        outputs = torch.cat([outputs.last_hidden_state, word_embeddings, cls_outputs], dim=2)
        fusion_outputs = self.fusion_layer(outputs)

        # Additive attention mask for the extra blocks: 0 for real tokens and a very
        # large negative value for padding, shaped (batch, 1, 1, seq) so it broadcasts
        # over heads and query positions. Without it, real tokens attend to padding.
        extended_mask = attention_mask[:, None, None, :].to(dtype=fusion_outputs.dtype)
        extended_mask = (1.0 - extended_mask) * torch.finfo(fusion_outputs.dtype).min

        x = fusion_outputs
        for transformer_layer in self.transformer_blocks:
            x = transformer_layer(x, attention_mask=extended_mask)[0]

        # Residual connection around the extra blocks, followed by normalization.
        outputs = self.norm(x + fusion_outputs)
        # Zero the scores at padding positions.
        return self.output_layer(outputs).squeeze(2) * attention_mask

In [18]:
# Smoke test: build a throwaway model and check the output shape for one batch.
# The recorded output is (32, 128), i.e. one score per token position of each sample.
test_model = CSCModel()
test_model(inputs).size()

Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([32, 128])

# Train

In [24]:
# Fresh model and loss function. The bookkeeping values below are the defaults for a new
# run; the checkpoint-restore cell overwrites them when a saved checkpoint exists.
model = CSCModel()
criteria = nn.BCELoss()
start_epoch, total_step = 0, 0
# History of (step, loss, recall, precision) tuples, one per logged window.
record = []
# Rolling window holding the 10 most recent batch losses.
last_10_losses = []

Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [19]:
# The model is moved to the device and the optimizer is built once, regardless of whether
# a checkpoint exists; saved state is then loaded on top of them.
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

if os.path.exists(model_path):
    # Load onto the CPU first; load_state_dict copies the values into the existing tensors.
    checkpoint = torch.load(model_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])

    # Resume the bookkeeping where the previous run stopped.
    start_epoch = checkpoint['epoch']
    total_step = checkpoint['total_step']
    record = checkpoint['record']
    last_10_losses = checkpoint['last_10_losses']
    print("恢复训练，epoch:", start_epoch)

In [20]:
# Make sure the model lives on the target device and is in training mode.
model = model.to(device).train()

In [22]:
writer = SummaryWriter(log_dir='runs/csc_model')
# Replay the history restored from the checkpoint so the curves continue from earlier runs.
for step, loss, recall, precision in record:
    for tag, value in (("record/loss", loss),
                       ("record/recall", recall),
                       ("record/precision", precision)):
        writer.add_scalar(tag=tag, scalar_value=value, global_step=step)

In [23]:
# Outside test mode, load the TensorBoard notebook extension and open it on the `runs`
# directory, which contains the SummaryWriter log directory used above.
if not test_mode:
    %load_ext tensorboard
    %tensorboard --logdir=runs
    # Pause for 10 seconds — presumably to give TensorBoard time to come up before the
    # training cell starts printing.
    time.sleep(10)

Launching TensorBoard...

In [27]:
# Accumulators for the current logging window; all of them are reset after each log line.
total_loss = 0.
window_steps = 0              # batches accumulated since the last log line
total_correct_wrong_char = 0  # wrong tokens the model flagged (true positives)
total_wrong_char = 0          # wrong tokens in the labels (recall denominator)
total_precision_num = 0       # tokens the model flagged (precision denominator)
total_precision_correct = 0   # flagged tokens that really are wrong

for epoch in range(start_epoch, epochs):

    step = 0

    # Seed the loss history with a large sentinel so the spike filter below has a
    # baseline to compare against before any real loss has been recorded.
    if len(last_10_losses) <= 0:
        last_10_losses.append(9999)

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criteria(outputs, targets)
        # Read the scalar once; it is reused for the filter, the history and the log.
        loss_value = loss.item()

        # Skip the parameter update when the loss is at least 3x the mean of the recent
        # losses, so that abnormal batches do not destabilise the model.
        if loss_value < sum(last_10_losses) / len(last_10_losses) * 3:
            loss.backward()
            optimizer.step()
        optimizer.zero_grad()

        step += 1
        total_step += 1
        window_steps += 1

        # Keep only the 10 most recent losses.
        last_10_losses.append(loss_value)
        if len(last_10_losses) > 10:
            last_10_losses = last_10_losses[-10:]

        total_loss += loss_value

        predicts = outputs >= 0.5
        # Wrong tokens that were successfully flagged.
        total_correct_wrong_char += predicts[targets == 1].sum().item()
        # Wrong tokens in this batch.
        total_wrong_char += (targets == 1).sum().item()

        total_precision_num += predicts.sum().item()
        total_precision_correct += (targets[predicts == 1]).sum().item()

        if total_step % log_after_step == 0:
            # Average over the batches actually seen in this window: after resuming from
            # a checkpoint the first window can hold fewer than log_after_step batches.
            avg_loss = total_loss / window_steps
            recall = total_correct_wrong_char / (total_wrong_char + 1e-9)
            precision = total_precision_correct / (total_precision_num + 1e-9)

            print("Epoch {}, "
                  "Step {}/{}, "
                  "Total Step {}, "
                  "loss {:.5f}, "
                  "recall {:.4f}, "
                  "precision {:.4f}".format(epoch, step, len(train_loader), total_step,
                                            avg_loss,
                                            recall,
                                            precision))
            writer.add_scalar(tag="record/loss", scalar_value=avg_loss, global_step=total_step)
            writer.add_scalar(tag="record/recall", scalar_value=recall, global_step=total_step)
            writer.add_scalar(tag="record/precision", scalar_value=precision, global_step=total_step)

            record.append((total_step, avg_loss, recall, precision,))

            total_loss = 0.
            window_steps = 0
            total_correct_wrong_char = 0
            total_wrong_char = 0
            total_precision_num = 0
            total_precision_correct = 0

    # Checkpoint at the end of every epoch; 'epoch' stores the next epoch to run on resume.
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch + 1,
        'total_step': total_step,
        'record': record,
        'last_10_losses': last_10_losses
    }, model_path)

KeyboardInterrupt: 

# Fine-tune

In [32]:
def _read_pickle(path):
    """Return the object stored in the pickle file at ``path`` (trusted files only)."""
    with open(path, mode='br') as f:
        return pickle.load(f)


# The three SIGHAN training sets are loaded separately and then concatenated.
sighan13_training_set = _read_pickle("sighan/Train/sighan13_training_set_simplified.pkl")
sighan14_training_set = _read_pickle("sighan/Train/sighan14_training_set_simplified.pkl")
sighan15_training_set = _read_pickle("sighan/Train/sighan15_training_set_simplified.pkl")
sighan_training_data = sighan13_training_set + sighan14_training_set + sighan15_training_set

In [33]:
# Wrap the combined records in a CSCDataset.
# NOTE: this rebinds the same name, so the cell is not idempotent — running it a second
# time would wrap the dataset in another dataset, whose items are tuples, not records.
sighan_training_data = CSCDataset(sighan_training_data)

In [34]:
# Same batching settings as the main training loader, applied to the SIGHAN training data.
sighan_train_loader = DataLoader(sighan_training_data, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, drop_last=True)

In [36]:
# Put the model in training mode and create a new Adam optimizer for fine-tuning;
# no optimizer state is carried over from the previous training stage.
model = model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

In [37]:
# Accumulators for the current logging window; all of them are reset after each log line.
total_loss = 0.
window_steps = 0              # batches accumulated since the last log line
total_correct_wrong_char = 0  # wrong tokens the model flagged (true positives)
total_wrong_char = 0          # wrong tokens in the labels (recall denominator)
total_precision_num = 0       # tokens the model flagged (precision denominator)
total_precision_correct = 0   # flagged tokens that really are wrong

# Fine-tune for 10 epochs on the SIGHAN training data.
for epoch in range(0, 10):

    step = 0

    for inputs, targets in sighan_train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criteria(outputs, targets)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        step += 1
        # total_step keeps counting from the value left by the previous training stage.
        total_step += 1
        window_steps += 1

        total_loss += loss.item()

        predicts = outputs >= 0.5
        # Wrong tokens that were successfully flagged.
        total_correct_wrong_char += predicts[targets == 1].sum().item()
        # Wrong tokens in this batch.
        total_wrong_char += (targets == 1).sum().item()

        total_precision_num += predicts.sum().item()
        total_precision_correct += (targets[predicts == 1]).sum().item()

        if total_step % log_after_step == 0:
            # Average over the batches actually seen in this window: total_step does not
            # start at a multiple of log_after_step in general, so the first window can
            # hold fewer than log_after_step batches.
            avg_loss = total_loss / window_steps
            recall = total_correct_wrong_char / (total_wrong_char + 1e-9)
            precision = total_precision_correct / (total_precision_num + 1e-9)

            print("Epoch {}, "
                  "Step {}/{}, "
                  "Total Step {}, "
                  "loss {:.5f}, "
                  "recall {:.4f}, "
                  "precision {:.4f}".format(epoch, step, len(sighan_train_loader), total_step,
                                            avg_loss,
                                            recall,
                                            precision))

            total_loss = 0.
            window_steps = 0
            total_correct_wrong_char = 0
            total_wrong_char = 0
            total_precision_num = 0
            total_precision_correct = 0

    # Save the fine-tuned weights after every epoch, next to the main checkpoint.
    torch.save({
        'model': model.state_dict()
    }, model_path.replace("model.pt", "model-final.pt"))


Epoch 0, Step 1/202, Total Step 1, loss 0.01649, recall 1.0000, precision 0.0245


KeyboardInterrupt: 

# Inference

In [None]:
# Restore weights for inference. Prefer the fine-tuned checkpoint written by the
# fine-tuning loop ("...model-final.pt"); fall back to the main training checkpoint so
# the previous behaviour is kept when no fine-tuned file exists.
final_model_path = model_path.replace("model.pt", "model-final.pt")
for checkpoint_path in (final_model_path, model_path):
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        break

# Move the model to the target device whether or not a checkpoint was found.
model = model.to(device)

In [38]:
# Switch the model to evaluation mode before running prediction and evaluation.
model = model.eval()

In [41]:
def predict(text):
    """Flag suspicious characters in ``text`` with the notebook-level ``model``.

    Returns:
        ``(highlighted, flags)`` where ``highlighted`` is ``text`` with every flagged
        character wrapped in ANSI red escape codes, and ``flags`` is an int tensor with
        one 0/1 entry per token between the first and last (special) tokens.

    NOTE(review): token positions are used directly as character positions, which
    assumes the tokenizer emits exactly one token per character — confirm for inputs
    containing Latin words, digits or whitespace.
    """
    inputs = tokenizer(text, return_tensors='pt')
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(inputs.to(device))

    # Drop the first and last positions (special tokens) and threshold at 0.5.
    flagged = outputs[0, 1:-1] >= 0.5

    segments = []
    last_index = 0
    for index in torch.where(flagged)[0].tolist():
        # Unflagged run before this character, then the character itself in red.
        segments.append(text[last_index:index])
        segments.append("\033[1;31m" + text[index] + "\033[0m")
        last_index = index + 1
    segments.append(text[last_index:])

    return ''.join(segments), flagged.int()

In [42]:
text, output = predict("我昨天吃了以个火聋果")
print(text)
print(output)

[1;31m我[0m[1;31m昨[0m[1;31m天[0m[1;31m吃[0m[1;31m了[0m[1;31m以[0m[1;31m个[0m[1;31m火[0m[1;31m聋[0m[1;31m果[0m
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32)


# Evaluation

In [43]:
# Read the pickled SIGHAN15 test records used by the evaluation loop below.
# NOTE: pickle.load can execute arbitrary code — only use it on trusted files.
with open("sighan/Test/sighan15_test_set_simplified.pkl", mode='br') as test_file:
    test_data = pickle.load(test_file)

In [45]:
# Running totals for token-level detection metrics over the whole test set.
total_recall_num = 0         # wrong tokens in the labels
total_recall_correct = 0     # wrong tokens the model flagged

total_precision_num = 0      # tokens the model flagged
total_precision_correct = 0  # flagged tokens that really are wrong

hard_data = []    # (src, tgt) pairs with at least one missed or spurious detection
result_list = []  # (sample id, flagged positions) pairs for the prediction file

progress = tqdm(range(len(test_data)))
for i in progress:
    src, tgt = test_data[i]['src'], test_data[i]['tgt']

    # Strip the first and last (special) tokens from both sides.
    inputs = tokenizer(src, return_tensors='pt', max_length=max_length, truncation=True)
    src_tokens = inputs['input_ids'][0][1:-1]
    tgt_tokens = tokenizer(tgt, return_tensors='pt', max_length=max_length, truncation=True)['input_ids'][0][1:-1]

    # Labels can only be derived when both sides tokenize to the same length.
    if len(src_tokens) != len(tgt_tokens):
        print("第%d条数据异常" % i)
        continue

    # Labels: True marks a wrong token. Built from the CPU token ids.
    target = src_tokens != tgt_tokens

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(inputs.to(device))[0][1:-1]

    # Predictions: True marks a token the model considers wrong. Same >= 0.5 threshold
    # as training and predict(); moved to the CPU so that predictions and labels can
    # index each other regardless of the device the model runs on.
    pred = (output >= 0.5).cpu()

    # Number of wrong tokens in this sample.
    recall_num = target.sum().item()
    total_recall_num += recall_num
    # How many of those wrong tokens the model flagged.
    recall_correct = pred[target].sum().item()
    total_recall_correct += recall_correct

    # Number of tokens the model flagged.
    precision_num = pred.sum().item()
    total_precision_num += precision_num
    # How many of the flagged tokens really are wrong.
    precision_correct = target[pred].sum().item()
    total_precision_correct += precision_correct

    # Cumulative metrics over all samples processed so far.
    recall = total_recall_correct / (total_recall_num + 1e-9)
    precision = total_precision_correct / (total_precision_num + 1e-9)
    f1_score = 2 * (recall * precision) / (recall + precision + 1e-9)

    if recall_num != recall_correct or precision_correct != precision_num:
        hard_data.append((src, tgt))

    progress.set_postfix({
        'recall': recall,
        'precision': precision,
        'f1-score': f1_score
    })

    # NOTE(review): positions are stored as token index + 2 — confirm this offset
    # against the position convention expected by the prediction file's consumer.
    result_list.append((test_data[i]['id'], (torch.where(pred)[0] + 2).tolist()))

  8%|▊         | 86/1100 [00:07<01:32, 10.94it/s, recall=0, precision=0, f1-score=0]


KeyboardInterrupt: 

In [67]:
# Write one line per evaluated sample: the id followed by ", <position>, <char>" for each
# flagged position, or ", 0" when nothing was flagged. The character written after each
# position is a fixed placeholder.
with open('sighan15_predict.txt', mode='w', encoding='utf-8') as f:
    lines = []
    for sample_id, positions in result_list:
        if positions:
            suffix = ''.join(', ' + str(pos) + ', 鸡' for pos in positions)
        else:
            suffix = ', 0'
        lines.append(sample_id + suffix + "\n")
    f.writelines(lines)