# Environment

In [134]:
try:
    import transformers
except:
    !pip install transformers

In [198]:
import transformers
import pickle
import random
import math

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm

In [199]:
torch.__version__

'1.12.1+cu113'

In [200]:
transformers.__version__

'4.21.3'

# Global Config

In [234]:
max_length = 128
batch_size = 32
n_epochs = 10

log_after_step = 20

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device:", device)

device: cuda


In [227]:
test_mode = False
if test_mode:
    n_epochs = 1000
    data_length = 64
    batch_size = 32
    log_after_step = 1

In [203]:
!gdown '1dC09i57lobL91lEbpebDuUBS0fGz-LAk' --folder --output data

'gdown' 不是内部或外部命令，也不是可运行的程序
或批处理文件。


# Data

In [204]:
class CSCDataset(Dataset):

    def __init__(self):
        super(CSCDataset, self).__init__()
        with open("data/trainall.times2.pkl", mode='br') as f:
            train_data = pickle.load(f)

        self.train_data = train_data

    def __getitem__(self, index):
        src = self.train_data[index]['src']
        tgt = self.train_data[index]['tgt']
        return src, tgt

    def __len__(self):
        if test_mode:
            return data_length
        return len(self.train_data)

In [205]:
train_data = CSCDataset()

In [206]:
train_data.__getitem__(0)

('纽约早盘作为基准的低硫轻油，五月份交割价攀升一点三四美元，来到每桶二十八点二五美元，而上周五曾下挫一豪元以上。',
 '纽约早盘作为基准的低硫轻油，五月份交割价攀升一点三四美元，来到每桶二十八点二五美元，而上周五曾下挫一美元以上。')

## Dataloader

In [207]:
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

In [208]:
def collate_fn(batch):
    src, tgt = zip(*batch)
    src, tgt = list(src), list(tgt)

    # src = tokenizer(src, padding='max_length', max_length=max_length, return_tensors='pt', truncation=True)
    inputs = tokenizer(tgt, padding='max_length', max_length=max_length, return_tensors='pt', truncation=True)

    input_ids_with_noise = []
    targets = torch.zeros(len(batch), max_length)
    for i, indices in enumerate(inputs['input_ids']):
        sentence_len = len(indices[indices != 0]) -1
        noise_len = math.floor(sentence_len * 0.15)
        noise_indices = random.sample(range(1, sentence_len), noise_len)
        noise = random.sample(range(1000, 10000), noise_len)
        for j, index in enumerate(noise_indices):
            indices[index] = noise[j]

        input_ids_with_noise.append(indices)

        targets[i][noise_indices] = 1

    inputs['input_ids'] = torch.stack(input_ids_with_noise)

    return inputs, targets

In [209]:
train_loader = DataLoader(train_data, batch_size=batch_size, collate_fn=collate_fn)

In [210]:
inputs, targets = next(iter(train_loader))

In [211]:
''.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))

'[CLS]蔺约早##ft作为基准的低硫轻油，五月份交割价铆升一点三四杜元，来到每桶二十八点慟五美元，而上周五痍下滲愉美元以上。[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]'

In [212]:
targets

tensor([[0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

# Model

In [213]:
# class BertModelWrapper(transformers.models.bert.modeling_bert.BertModel):
#
#     def __init__(self):
#         super =

In [214]:
class CSCModel(nn.Module):

    def __init__(self):
        super(CSCModel, self).__init__()

        self.semantic_encoder = AutoModel.from_pretrained("hfl/chinese-roberta-wwm-ext")

        self.output_layer = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        outputs = self.semantic_encoder(**inputs)
        return self.output_layer(outputs.last_hidden_state).squeeze(2)

In [215]:
model = CSCModel()

Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [216]:
model(inputs)

tensor([[0.4610, 0.4748, 0.4595,  ..., 0.4546, 0.4585, 0.4424],
        [0.4506, 0.4551, 0.4781,  ..., 0.4368, 0.4428, 0.4432],
        [0.4885, 0.4994, 0.5231,  ..., 0.4566, 0.4671, 0.4671],
        ...,
        [0.4670, 0.4592, 0.4707,  ..., 0.4326, 0.4310, 0.4464],
        [0.4498, 0.4499, 0.4305,  ..., 0.4628, 0.4309, 0.4490],
        [0.4372, 0.5033, 0.5079,  ..., 0.4355, 0.4388, 0.4360]],
       grad_fn=<SqueezeBackward1>)

# Train

In [217]:
model = CSCModel()
model = model.to(device)
model = model.train()

Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [218]:
criteria = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [239]:
step = 0
total_loss = 0.
total_correct = 0
total_num = 0
total_correct_wrong_char = 0
total_wrong_char = 0

for epoch in range(n_epochs):

    for inputs_, targets_ in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criteria(outputs, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        step += 1

        total_loss += loss.detach().item()
        total_correct += ((outputs >= 0.5).int() == targets.int()).sum().item()
        total_num += len(targets.flatten())

        # 统计成功预测错字的数量
        total_correct_wrong_char += (outputs >= 0.5)[targets == 1].sum().item()
        # 统计错字的数量
        total_wrong_char += (targets == 1).sum().item()

        if step % log_after_step == 0:
            print("Epoch {}, Step {}, loss {:.5f}, accuracy {:.4f}, recall {:.4f}".format(epoch, step,
                                                                                          total_loss / log_after_step,
                                                                                          total_correct / total_num,
                                                                                          total_correct_wrong_char / total_wrong_char))
            total_loss = 0.
            total_correct = 0
            total_num = 0
            total_correct_wrong_char = 0
            total_wrong_char = 0


Epoch 0, Step 20, loss 0.00186, accuracy 0.9997, recall 0.9977


KeyboardInterrupt: 

# Inference

In [220]:
model = torch.load('csc-model.pt', map_location='cpu')

FileNotFoundError: [Errno 2] No such file or directory: 'csc-model.pt'

In [261]:
def predict(text):
    inputs = tokenizer(text, return_tensors='pt')
    outputs = model(inputs.to(device))

    segments = []
    last_index = 0
    for index in torch.where(outputs[0, 1:-1] >= 0.5)[0].tolist():
        segments.append(text[last_index:index])
        segments.append("\033[1;31m" + text[index] + "\033[0m")
        last_index = index + 1
    segments.append(text[last_index:])

    return ''.join(segments), outputs

In [264]:
text, output = predict("昨天你核稿了是不是")
print(text)
print(output)

昨天你核[1;31m稿[0m了是不是
tensor([[5.7550e-04, 3.3457e-04, 3.9097e-04, 3.1344e-04, 1.2680e-02, 9.9691e-01,
         3.6202e-04, 7.9070e-04, 3.2427e-04, 3.7783e-04, 4.9322e-04]],
       device='cuda:0', grad_fn=<SqueezeBackward1>)


In [248]:
text

'今\x1b[1;31m填\x1b[0\x1b[1;31mm\x1b[0m非常难受，因为女朋友药跟我闹分手'

# Evaluation

In [236]:
with open("data/test.sighan15.pkl", mode='br') as f:
    test_data = pickle.load(f)

In [238]:
total_num = 0
total_correct = 0

total_recall_num = 0
total_recall_correct = 0

total_precision_num = 0
total_precision_correct = 0

prograss = tqdm(range(len(test_data)))
for i in prograss:
    src, tgt = test_data[i]['src'], test_data[i]['tgt']

    _, output = predict(src)
    target = (torch.tensor(test_data[i]['src_idx']) != torch.tensor(test_data[i]['tgt_idx'])).int()[1:-1].to(device)

    if len(output) != len(target):
        print("第%d条数据异常" % i)


    total_num += 1
    if (output != target).sum().item() == 0:
        total_correct += 1

    total_recall_correct += output[target == 1].sum().item()
    total_recall_num += (target == 1).sum().item()

    total_precision_num += (output == 1).sum().item()
    total_precision_correct += (target[output == 1]).sum().item()

    accuracy = total_correct / total_num
    recall = total_recall_correct / (total_recall_num + 1e-9)
    precision = total_precision_correct / (total_precision_num + 1e-9)

    prograss.set_postfix({
        'accuracy': accuracy,
        'recall': recall,
        'precision': precision
    })

100%|██████████| 1100/1100 [00:10<00:00, 106.32it/s, accuracy=0.528, recall=0.413, precision=0.445]
