# Environment

In [5]:
try:
    import transformers
except:
    !pip install transformers

In [120]:
import transformers
import pickle

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm

In [7]:
# Display the installed transformers version so the run can be reproduced.
transformers.__version__

'4.21.2'

# Global Config

In [8]:
# Tokenisation / batching settings used by collate_fn and the DataLoader.
max_length = 128  # sequences are padded or truncated to this many tokens
batch_size = 32   # examples per batch

# The training loop prints its running metrics once every this many steps.
log_after_step = 20

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [9]:
# Download the shared folder with the given id into ./data using the `gdown` CLI.
# The `gdown` executable must be on PATH; the recorded output below shows the
# shell could not find it on the machine where this notebook was last run.
# NOTE(review): later cells read pickles from data/ — presumably they come from this folder; confirm.
!gdown '1dC09i57lobL91lEbpebDuUBS0fGz-LAk' --folder --output data

'gdown' 不是内部或外部命令，也不是可运行的程序
或批处理文件。


# Data

In [10]:
class CSCDataset(Dataset):
    """(source, target) sentence pairs loaded from a pickle file.

    The pickle must contain an indexable, sized collection whose items are
    mappings with a ``'src'`` and a ``'tgt'`` entry.
    """

    def __init__(self, data_path="data/trainall.times2.pkl"):
        """Read the whole pickle into memory.

        Args:
            data_path: path of the pickle file to load. Defaults to the
                training file this notebook was written for, so
                ``CSCDataset()`` keeps working unchanged.
        """
        super().__init__()
        # SECURITY: pickle.load can execute arbitrary code; only open files
        # that come from a trusted source.
        with open(data_path, mode='rb') as f:
            self.train_data = pickle.load(f)

    def __getitem__(self, index):
        """Return the ``(src, tgt)`` pair stored at ``index``."""
        record = self.train_data[index]
        return record['src'], record['tgt']

    def __len__(self):
        """Number of sentence pairs."""
        return len(self.train_data)

In [11]:
# Build the dataset; the constructor reads the pickle file from disk.
train_data = CSCDataset()

In [12]:
# Peek at the first (src, tgt) pair.
train_data[0]

('纽约早盘作为基准的低硫轻油，五月份交割价攀升一点三四美元，来到每桶二十八点二五美元，而上周五曾下挫一豪元以上。',
 '纽约早盘作为基准的低硫轻油，五月份交割价攀升一点三四美元，来到每桶二十八点二五美元，而上周五曾下挫一美元以上。')

## Dataloader

In [13]:
# Tokenizer for the same checkpoint name that CSCModel loads as its encoder.
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

In [14]:
def collate_fn(batch):
    """Turn a list of (src, tgt) string pairs into model inputs and labels.

    Both sides are tokenized with padding/truncation to ``max_length``.

    Returns:
        A pair ``(encoded_src, labels)``: the tokenizer output for the source
        sentences, and a float tensor that is 1.0 wherever the source token id
        differs from the target token id and 0.0 elsewhere.
    """
    sources, corrections = map(list, zip(*batch))

    # Same tokenizer settings for both sides so positions line up.
    encode_kwargs = dict(
        padding='max_length',
        max_length=max_length,
        return_tensors='pt',
        truncation=True,
    )
    encoded_src = tokenizer(sources, **encode_kwargs)
    encoded_tgt = tokenizer(corrections, **encode_kwargs)

    differs = encoded_src['input_ids'] != encoded_tgt['input_ids']
    return encoded_src, differs.float()

In [15]:
# Batches for training. shuffle=True reshuffles the examples every epoch;
# DataLoader's default (shuffle=False) would replay the file order each time.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

In [16]:
# Pull a single batch to inspect what collate_fn produces.
# NOTE(review): this binds the notebook-level names `inputs` and `targets`;
# later cells that use those names will see this one batch unless they rebind them.
inputs, targets = next(iter(train_loader))

In [17]:
# Labels from collate_fn: 1.0 where the source and target token ids differ.
targets

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 1.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

# Model

In [18]:
class CSCModel(nn.Module):
    """Per-token binary classifier on top of a pretrained encoder.

    The encoder's last hidden state is fed through a small MLP ending in a
    sigmoid, giving one probability in [0, 1] per token position.
    """

    def __init__(self, pretrained_name="hfl/chinese-roberta-wwm-ext"):
        """
        Args:
            pretrained_name: checkpoint passed to ``AutoModel.from_pretrained``.
                Defaults to the checkpoint this notebook was written for.
        """
        super().__init__()

        self.semantic_encoder = AutoModel.from_pretrained(pretrained_name)

        # Size the head from the encoder's own config instead of hard-coding
        # 768, so a checkpoint with a different hidden width still works.
        hidden_size = self.semantic_encoder.config.hidden_size

        self.output_layer = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        """Return per-token probabilities.

        Args:
            inputs: mapping of keyword arguments for the encoder (as produced
                by the tokenizer with ``return_tensors='pt'``).

        Returns:
            Tensor with the trailing size-1 dimension of the head output
            removed (one value per token position).
        """
        outputs = self.semantic_encoder(**inputs)
        return self.output_layer(outputs.last_hidden_state).squeeze(2)

# Train

In [19]:
# Build the model, move it to the selected device and put it in training mode.
model = CSCModel().to(device).train()

Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [20]:
# Binary cross-entropy on probabilities (the model's head already ends in a Sigmoid).
criteria = nn.BCELoss()
# Adam over every model parameter (encoder and head) with a single learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [21]:
# Running counters; they are reset after every log line, so each printed
# metric covers only the last `log_after_step` steps.
step = 0
total_loss = 0.
total_correct = 0
total_num = 0
total_correct_wrong_char = 0
total_wrong_char = 0

for epoch in range(10):

    # Train on the batch the loader yields. The loop variables must be the
    # ones that are moved to the device and fed to the model; using the
    # notebook-level `inputs`/`targets` here would train on one stale batch.
    for batch_inputs, batch_targets in train_loader:
        batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)
        outputs = model(batch_inputs)
        loss = criteria(outputs, batch_targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        step += 1

        total_loss += loss.detach().item()
        # Position-level accuracy at a 0.5 threshold.
        total_correct += ((outputs >= 0.5).int() == batch_targets.int()).sum().item()
        total_num += len(batch_targets.flatten())

        # Number of wrong characters that were successfully predicted.
        total_correct_wrong_char += (outputs >= 0.5)[batch_targets == 1].sum().item()
        # Number of wrong characters.
        total_wrong_char += (batch_targets == 1).sum().item()

        if step % log_after_step == 0:
            # Guard the recall denominator: a logging window can contain no
            # positive labels at all.
            print("Epoch {}, Step {}, loss {:.5f}, accuracy {:.4f}, recall {:.4f}".format(
                epoch, step,
                total_loss / log_after_step,
                total_correct / total_num,
                total_correct_wrong_char / max(total_wrong_char, 1)))
            total_loss = 0.
            total_correct = 0
            total_num = 0
            total_correct_wrong_char = 0
            total_wrong_char = 0


KeyboardInterrupt: 

# Inference

In [23]:
# Load the saved model object (presumably a pickled CSCModel, so that class
# must already be defined). torch.load unpickles the file: only load
# checkpoints from a trusted source.
# Map it onto `device`, the same device predict() moves its inputs to, and
# switch to eval mode so dropout is disabled during inference.
model = torch.load('csc-model.pt', map_location=device)
model = model.to(device).eval()

In [78]:
def predict(text):
    """Flag the characters of ``text`` that the model scores at 0.5 or above.

    Args:
        text: a single sentence (str).

    Returns:
        A pair ``(highlighted, flags)``. ``highlighted`` is ``text`` with each
        flagged character wrapped in ANSI red escape codes. ``flags`` is a 1-D
        int32 CPU tensor of 0/1 values, one per token position with the first
        and last position removed.
    """
    inputs = tokenizer(text, return_tensors='pt')
    # Inference only: do not build the autograd graph.
    with torch.no_grad():
        outputs = model(inputs.to(device))

    # Drop the first and last positions (presumably the tokenizer's special
    # start/end tokens) and move the result to the CPU so callers can compare
    # it with CPU tensors.
    flags = (outputs[0, 1:-1] >= 0.5).int().cpu()
    flagged = set(torch.where(flags == 1)[0].tolist())

    # Build the highlighted string from the untouched input. Inserting escape
    # codes into `text` while still indexing into it would shift every later
    # index and highlight the wrong characters.
    # NOTE(review): assumes token i corresponds to character i — TODO confirm
    # for text the tokenizer does not split one token per character.
    highlighted = "".join(
        "\033[1;31m" + char + "\033[0m" if position in flagged else char
        for position, char in enumerate(text)
    )

    return highlighted, flags

In [79]:
# Run the detector on one sentence: print the sentence with flagged
# characters wrapped in ANSI red escape codes, then the 0/1 flags.
text, output = predict("今天非常难受，因为女朋友药跟我闹分手")
print(text)
print(output)

今天非常难受，因为女朋友[1;31m药[0m跟我闹分手
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
       dtype=torch.int32)


# Evaluation

In [80]:
# Read the evaluation set from its pickle file.
# (pickle can execute code on load; the file must come from a trusted source.)
with open("data/test.sighan15.pkl", "rb") as test_file:
    test_data = pickle.load(test_file)

In [150]:
# Sentence-level accuracy: a sentence counts as correct only when every
# position matches. Recall / precision are counted over flagged positions.
total_num = 0
total_correct = 0

total_recall_num = 0
total_recall_correct = 0

total_precision_num = 0
total_precision_correct = 0

# Samples whose prediction length does not match the reference length.
skipped = 0

progress = tqdm(range(len(test_data)))
for i in progress:
    sample = test_data[i]

    _, output = predict(sample['src'])
    # Compare on the CPU, where `target` is built.
    output = output.cpu()
    # Reference flags: positions where the source and target ids differ, with
    # the first and last position dropped to mirror predict().
    target = (torch.tensor(sample['src_idx']) != torch.tensor(sample['tgt_idx'])).int()[1:-1]

    if len(output) != len(target):
        print("第%d条数据异常" % i)
        # The two tensors cannot be compared position by position (the
        # comparison would raise or silently broadcast), so leave this sample
        # out of the metrics.
        skipped += 1
        continue

    total_num += 1
    if (output != target).sum().item() == 0:
        total_correct += 1

    total_recall_correct += output[target == 1].sum().item()
    total_recall_num += (target == 1).sum().item()

    total_precision_num += (output == 1).sum().item()
    total_precision_correct += (target[output == 1]).sum().item()

    accuracy = total_correct / total_num
    # The small epsilon avoids division by zero before any positive is seen.
    recall = total_recall_correct / (total_recall_num + 1e-9)
    precision = total_precision_correct / (total_precision_num + 1e-9)

    progress.set_postfix({
        'accuracy': accuracy,
        'recall': recall,
        'precision': precision
    })

100%|██████████| 1100/1100 [01:41<00:00, 10.82it/s, accuracy=0.716, recall=0.595, precision=0.766]
