# Environment

In [1]:
try:
    import transformers
except:
    !pip install transformers

In [2]:
import transformers
import pickle
import random
import math

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm

In [3]:
torch.__version__

'1.12.1+cu113'

In [4]:
transformers.__version__

'4.21.3'

# Global Config

In [5]:
max_length = 128
batch_size = 32
n_epochs = 10

log_after_step = 50

valid_after_step = 200

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device:", device)

device: cuda


In [6]:
test_mode = False
if test_mode:
    n_epochs = 1000
    data_length = 64
    batch_size = 32
    log_after_step = 1

In [7]:
!gdown '1dC09i57lobL91lEbpebDuUBS0fGz-LAk' --folder --output data

'gdown' 不是内部或外部命令，也不是可运行的程序
或批处理文件。


# Data

In [8]:
class CSCDataset(Dataset):

    def __init__(self):
        super(CSCDataset, self).__init__()
        with open("data/trainall.times2.pkl", mode='br') as f:
            train_data = pickle.load(f)

        self.train_data = train_data

    def __getitem__(self, index):
        src = self.train_data[index]['src']
        tgt = self.train_data[index]['tgt']
        return src, tgt

    def __len__(self):
        if test_mode:
            return data_length
        return len(self.train_data)

In [9]:
train_data = CSCDataset()

In [10]:
train_data.__getitem__(0)

('纽约早盘作为基准的低硫轻油，五月份交割价攀升一点三四美元，来到每桶二十八点二五美元，而上周五曾下挫一豪元以上。',
 '纽约早盘作为基准的低硫轻油，五月份交割价攀升一点三四美元，来到每桶二十八点二五美元，而上周五曾下挫一美元以上。')

## Dataloader

In [11]:
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

In [12]:
def collate_fn(batch):
    src, tgt = zip(*batch)
    src, tgt = list(src), list(tgt)

    # src = tokenizer(src, padding='max_length', max_length=max_length, return_tensors='pt', truncation=True)
    inputs = tokenizer(tgt, padding='max_length', max_length=max_length, return_tensors='pt', truncation=True)

    input_ids_with_noise = []
    targets = torch.zeros(len(batch), max_length)
    for i, indices in enumerate(inputs['input_ids']):
        sentence_len = len(indices[indices != 0]) -1
        noise_len = math.floor(sentence_len * 0.15)
        noise_indices = random.sample(range(1, sentence_len), noise_len)
        noise = random.sample(range(670, 7992), noise_len)
        for j, index in enumerate(noise_indices):
            indices[index] = noise[j]

        input_ids_with_noise.append(indices)

        targets[i][noise_indices] = 1

    inputs['input_ids'] = torch.stack(input_ids_with_noise)

    return inputs, targets

In [13]:
train_loader = DataLoader(train_data, batch_size=batch_size, collate_fn=collate_fn)

In [14]:
inputs, targets = next(iter(train_loader))

In [15]:
''.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))

'[CLS]纽约早盘作为基准的低硫趴島熠売月份交淼价攀升一点三四美元，来到每桶二十八点二五美元，而上周五股下挫一美元胃學。[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]'

In [16]:
len(train_loader)

8882

In [17]:
vocab = {}
for i, token in enumerate(tokenizer.convert_ids_to_tokens(range(0, len(tokenizer)))):
    vocab[i] = token

In [18]:
import json
with open('vocab.json', 'w', encoding='utf-8') as file:
     file.write(json.dumps(vocab, ensure_ascii=False))

# Model

In [19]:
# class BertModelWrapper(transformers.models.bert.modeling_bert.BertModel):
#
#     def __init__(self):
#         super =

In [20]:
class CSCModel(nn.Module):

    def __init__(self):
        super(CSCModel, self).__init__()

        self.semantic_encoder = AutoModel.from_pretrained("hfl/chinese-roberta-wwm-ext")

        self.output_layer = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        outputs = self.semantic_encoder(**inputs)
        return self.output_layer(outputs.last_hidden_state).squeeze(2)

In [21]:
model = CSCModel()
model = model.to(device)

Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [22]:
model(inputs.to(device))

tensor([[0.4450, 0.4659, 0.4733,  ..., 0.4327, 0.4611, 0.4545],
        [0.5332, 0.5175, 0.4943,  ..., 0.5010, 0.5004, 0.4832],
        [0.4704, 0.5167, 0.4837,  ..., 0.4719, 0.4846, 0.4679],
        ...,
        [0.4940, 0.5044, 0.5484,  ..., 0.4661, 0.4622, 0.4756],
        [0.4834, 0.4944, 0.4990,  ..., 0.5122, 0.4710, 0.4663],
        [0.5026, 0.4950, 0.5046,  ..., 0.4993, 0.5040, 0.5091]],
       device='cuda:0', grad_fn=<SqueezeBackward1>)

# Validation

In [23]:
def predict(text):
    inputs = tokenizer(text, return_tensors='pt')
    outputs = model(inputs.to(device))

    segments = []
    last_index = 0
    for index in torch.where(outputs[0, 1:-1] >= 0.5)[0].tolist():
        segments.append(text[last_index:index])
        segments.append("\033[1;31m" + text[index] + "\033[0m")
        last_index = index + 1
    segments.append(text[last_index:])

    return ''.join(segments), (outputs[0, 1:-1] >= 0.5).int()

def validation():
    with open("data/test.sighan15.pkl", mode='br') as f:
        test_data = pickle.load(f)

    total_num = 0
    total_correct = 0

    total_recall_num = 0
    total_recall_correct = 0

    total_precision_num = 0
    total_precision_correct = 0

    prograss = tqdm(range(len(test_data)))
    for i in prograss:
        src, tgt = test_data[i]['src'], test_data[i]['tgt']

        _, output = predict(src)
        target = (torch.tensor(test_data[i]['src_idx']) != torch.tensor(test_data[i]['tgt_idx'])).int()[1:-1].to(device)

        if len(output) != len(target):
            print("第%d条数据异常，请检查" % i)
            break

        total_num += 1
        if (output != target).sum().item() == 0:
            total_correct += 1

        total_recall_correct += output[target == 1].sum().item()
        total_recall_num += (target == 1).sum().item()

        total_precision_num += (output == 1).sum().item()
        total_precision_correct += (target[output == 1]).sum().item()

        accuracy = total_correct / total_num
        recall = total_recall_correct / (total_recall_num + 1e-9)
        precision = total_precision_correct / (total_precision_num + 1e-9)

        prograss.set_postfix({
            'accuracy': accuracy,
            'recall': recall,
            'precision': precision
        })

In [24]:
validation()

100%|██████████| 1100/1100 [00:08<00:00, 132.53it/s, accuracy=0, recall=0.563, precision=0.0164]


# Train

In [25]:
model = model.train()

In [26]:
criteria = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [27]:
step = 0
total_loss = 0.
total_correct = 0
total_num = 0
total_correct_wrong_char = 0
total_wrong_char = 0

for epoch in range(n_epochs):

    for inputs_, targets_ in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criteria(outputs, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        step += 1

        total_loss += loss.detach().item()
        total_correct += ((outputs >= 0.5).int() == targets.int()).sum().item()
        total_num += len(targets.flatten())

        # 统计成功预测错字的数量
        total_correct_wrong_char += (outputs >= 0.5)[targets == 1].sum().item()
        # 统计错字的数量
        total_wrong_char += (targets == 1).sum().item()

        if step % log_after_step == 0:
            print("Epoch {}, Step {}, loss {:.5f}, accuracy {:.4f}, recall {:.4f}".format(epoch, step,
                                                                                          total_loss / log_after_step,
                                                                                          total_correct / total_num,
                                                                                          total_correct_wrong_char / total_wrong_char))
            total_loss = 0.
            total_correct = 0
            total_num = 0
            total_correct_wrong_char = 0
            total_wrong_char = 0

        if step % valid_after_step == 0:
            validation()

Epoch 0, Step 50, loss 0.07968, accuracy 0.9764, recall 0.7198
Epoch 0, Step 100, loss 0.00552, accuracy 0.9991, recall 0.9976
Epoch 0, Step 150, loss 0.00144, accuracy 0.9997, recall 0.9995
Epoch 0, Step 200, loss 0.00041, accuracy 0.9999, recall 0.9998


100%|██████████| 1100/1100 [00:11<00:00, 98.71it/s, accuracy=0.505, recall=0.319, precision=0.418] 


Epoch 0, Step 250, loss 0.00019, accuracy 1.0000, recall 1.0000
Epoch 0, Step 300, loss 0.00868, accuracy 0.9985, recall 0.9871
Epoch 0, Step 350, loss 0.00024, accuracy 1.0000, recall 1.0000
Epoch 0, Step 400, loss 0.00017, accuracy 1.0000, recall 0.9999


100%|██████████| 1100/1100 [00:09<00:00, 113.81it/s, accuracy=0.515, recall=0.201, precision=0.464]


KeyboardInterrupt: 

# Inference

In [None]:
model = torch.load('csc-model.pt', map_location='cpu')

In [None]:
def predict(text):
    inputs = tokenizer(text, return_tensors='pt')
    outputs = model(inputs.to(device))

    segments = []
    last_index = 0
    for index in torch.where(outputs[0, 1:-1] >= 0.5)[0].tolist():
        segments.append(text[last_index:index])
        segments.append("\033[1;31m" + text[index] + "\033[0m")
        last_index = index + 1
    segments.append(text[last_index:])

    return ''.join(segments), (outputs[0, 1:-1] >= 0.5).int()

In [None]:
text, output = predict("昨天下雨了你紫道吗，但是有些词确实太男了")
print(text)
print(output)

In [None]:
text

# Evaluation

In [None]:
with open("data/test.sighan15.pkl", mode='br') as f:
    test_data = pickle.load(f)

In [None]:
total_num = 0
total_correct = 0

total_recall_num = 0
total_recall_correct = 0

total_precision_num = 0
total_precision_correct = 0

prograss = tqdm(range(len(test_data)))
for i in prograss:
    src, tgt = test_data[i]['src'], test_data[i]['tgt']

    _, output = predict(src)
    target = (torch.tensor(test_data[i]['src_idx']) != torch.tensor(test_data[i]['tgt_idx'])).int()[1:-1].to(device)

    if len(output) != len(target):
        print("第%d条数据异常" % i)


    total_num += 1
    if (output != target).sum().item() == 0:
        total_correct += 1

    total_recall_correct += output[target == 1].sum().item()
    total_recall_num += (target == 1).sum().item()

    total_precision_num += (output == 1).sum().item()
    total_precision_correct += (target[output == 1]).sum().item()

    accuracy = total_correct / total_num
    recall = total_recall_correct / (total_recall_num + 1e-9)
    precision = total_precision_correct / (total_precision_num + 1e-9)

    prograss.set_postfix({
        'accuracy': accuracy,
        'recall': recall,
        'precision': precision
    })