# Environment

In [1]:
try:
    import transformers
except:
    !pip install transformers

In [2]:
import transformers
import pickle

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm

In [3]:
transformers.__version__

'4.21.3'

# Global Config

In [4]:
max_length = 128
batch_size = 32

log_after_step = 20

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
!gdown '1dC09i57lobL91lEbpebDuUBS0fGz-LAk' --folder --output data

'gdown' 不是内部或外部命令，也不是可运行的程序
或批处理文件。


# Data

In [6]:
class CSCDataset(Dataset):

    def __init__(self):
        super(CSCDataset, self).__init__()
        with open("data/trainall.times2.pkl", mode='br') as f:
            train_data = pickle.load(f)

        self.train_data = train_data

    def __getitem__(self, index):
        src = self.train_data[index]['src']
        tgt = self.train_data[index]['tgt']
        return src, tgt

    def __len__(self):
        return len(self.train_data)

In [7]:
train_data = CSCDataset()

In [8]:
train_data.__getitem__(0)

('纽约早盘作为基准的低硫轻油，五月份交割价攀升一点三四美元，来到每桶二十八点二五美元，而上周五曾下挫一豪元以上。',
 '纽约早盘作为基准的低硫轻油，五月份交割价攀升一点三四美元，来到每桶二十八点二五美元，而上周五曾下挫一美元以上。')

## Dataloader

In [9]:
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

In [10]:
def collate_fn(batch):
    src, tgt = zip(*batch)
    src, tgt = list(src), list(tgt)

    src = tokenizer(src, padding='max_length', max_length=max_length, return_tensors='pt', truncation=True)
    tgt = tokenizer(tgt, padding='max_length', max_length=max_length, return_tensors='pt', truncation=True)

    return src, (src['input_ids'] != tgt['input_ids']).float()

In [11]:
train_loader = DataLoader(train_data, batch_size=batch_size, collate_fn=collate_fn)

In [12]:
inputs, targets = next(iter(train_loader))

In [13]:
targets[0]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.])

# Model

In [14]:
class CSCModel(nn.Module):

    def __init__(self):
        super(CSCModel, self).__init__()

        self.semantic_encoder = AutoModel.from_pretrained("hfl/chinese-roberta-wwm-ext")
        self.word_embeddings = self.semantic_encoder.get_input_embeddings()

        self.output_layer = nn.Sequential(
            nn.Linear(768 * 3, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        token_num = inputs['input_ids'].size(1)
        outputs = self.semantic_encoder(**inputs)
        word_embeddings = self.word_embeddings(inputs['input_ids'])
        cls_outputs = outputs.last_hidden_state[:, 0:1, :].repeat(1, token_num, 1)
        outputs = torch.concat([outputs.last_hidden_state, word_embeddings, cls_outputs], dim=2)
        return self.output_layer(outputs).squeeze(2)

In [15]:
test_model = CSCModel()
test_model(inputs).size()

Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([32, 128])

# Train

In [16]:
model = CSCModel()
model = model.to(device)
model = model.train()

Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [17]:
criteria = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [26]:
step = 0
total_loss = 0.
total_correct = 0
total_num = 0
total_correct_wrong_char = 0
total_wrong_char = 0
total_precision_num = 0
total_precision_correct = 0

for epoch in range(10):

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criteria(outputs, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        step += 1

        total_loss += loss.detach().item()
        total_correct += ((outputs >= 0.5).int() == targets.int()).sum().item()
        total_num += len(targets.flatten())

        predicts = outputs >= 0.5
        # 统计成功预测错字的数量
        total_correct_wrong_char += predicts[targets == 1].sum().item()
        # 统计错字的数量
        total_wrong_char += (targets == 1).sum().item()

        total_precision_num += predicts.sum().item()
        total_precision_correct += (targets[predicts == 1]).sum().item()

        if step % log_after_step == 0:
            print("Epoch {}, Step {}, loss {:.5f}, accuracy {:.4f}, recall {:.4f}, precision {:.4f}".format(epoch, step,
                                                                                                            total_loss / log_after_step,
                                                                                                            total_correct / (total_num + 1e-9),
                                                                                                            total_correct_wrong_char / (total_wrong_char + 1e-9),
                                                                                                            total_precision_correct / (total_precision_num + 1e-9)))
            total_loss = 0.
            total_correct = 0
            total_num = 0
            total_correct_wrong_char = 0
            total_wrong_char = 0
            total_precision_num = 0
            total_precision_correct = 0


Epoch 0, Step 20, loss 0.00039, accuracy 0.9999, recall 0.9977, precision 0.9954
Epoch 0, Step 40, loss 0.00716, accuracy 0.9977, recall 0.8620, precision 0.9167
Epoch 0, Step 60, loss 0.00898, accuracy 0.9971, recall 0.8492, precision 0.9057
Epoch 0, Step 80, loss 0.00721, accuracy 0.9976, recall 0.8489, precision 0.9351
Epoch 0, Step 100, loss 0.00465, accuracy 0.9984, recall 0.9059, precision 0.9473
Epoch 0, Step 120, loss 0.00419, accuracy 0.9986, recall 0.9299, precision 0.9416
Epoch 0, Step 140, loss 0.00474, accuracy 0.9986, recall 0.9250, precision 0.9494
Epoch 0, Step 160, loss 0.00634, accuracy 0.9980, recall 0.8892, precision 0.9346
Epoch 0, Step 180, loss 0.00508, accuracy 0.9984, recall 0.9035, precision 0.9567
Epoch 0, Step 200, loss 0.00470, accuracy 0.9985, recall 0.9203, precision 0.9363
Epoch 0, Step 220, loss 0.00440, accuracy 0.9986, recall 0.9206, precision 0.9521
Epoch 0, Step 240, loss 0.00372, accuracy 0.9989, recall 0.9479, precision 0.9575
Epoch 0, Step 260, l

KeyboardInterrupt: 

# Inference

In [18]:
model = torch.load('csc-model.pt', map_location='cpu')

FileNotFoundError: [Errno 2] No such file or directory: 'csc-model.pt'

In [27]:
model = model.eval()

In [30]:
def predict(text):
    inputs = tokenizer(text, return_tensors='pt')
    outputs = model(inputs.to(device))

    segments = []
    last_index = 0
    for index in torch.where(outputs[0, 1:-1] >= 0.5)[0].tolist():
        segments.append(text[last_index:index])
        segments.append("\033[1;31m" + text[index] + "\033[0m")
        last_index = index + 1
    segments.append(text[last_index:])

    return ''.join(segments), (outputs[0, 1:-1] >= 0.5).int()

In [32]:
text, output = predict("怎么越来越才了")
print(text)
print(output)

怎么越来越[1;31m才[0m了
tensor([0, 0, 0, 0, 0, 1, 0], device='cuda:0', dtype=torch.int32)


# Evaluation

In [33]:
with open("data/test.sighan15.pkl", mode='br') as f:
    test_data = pickle.load(f)

In [34]:
total_num = 0
total_correct = 0

total_recall_num = 0
total_recall_correct = 0

total_precision_num = 0
total_precision_correct = 0

prograss = tqdm(range(len(test_data)))
for i in prograss:
    src, tgt = test_data[i]['src'], test_data[i]['tgt']

    _, output = predict(src)
    target = (torch.tensor(test_data[i]['src_idx']) != torch.tensor(test_data[i]['tgt_idx'])).int()[1:-1].to(device)

    if len(output) != len(target):
        print("第%d条数据异常" % i)

    total_num += 1
    if (output != target).sum().item() == 0:
        total_correct += 1

    total_recall_correct += output[target == 1].sum().item()
    total_recall_num += (target == 1).sum().item()

    total_precision_num += (output == 1).sum().item()
    total_precision_correct += (target[output == 1]).sum().item()

    accuracy = total_correct / total_num
    recall = total_recall_correct / (total_recall_num + 1e-9)
    precision = total_precision_correct / (total_precision_num + 1e-9)

    prograss.set_postfix({
        'accuracy': accuracy,
        'recall': recall,
        'precision': precision
    })

100%|██████████| 1100/1100 [00:09<00:00, 112.68it/s, accuracy=0.695, recall=0.656, precision=0.725]
