# Environment

In [1]:
try:
    import transformers
except:
    !pip install transformers

In [2]:
import time
import os
import pickle

import transformers
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm

In [3]:
transformers.__version__

'4.21.3'

# Global Config

In [4]:
max_length = 128
batch_size = 32
epochs = 10

log_after_step = 20

model_path = './drive/MyDrive/models/'
os.mkdir(model_path) if not os.path.exists(model_path) else ''
model_path = model_path + 'csc-model.pt'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
test_mode = False
if test_mode:
    epochs = 1000
    data_length = 10
    batch_size = 1
    log_after_step = 1
    model_path = 'csc-model.pt'

In [6]:
!gdown '1dC09i57lobL91lEbpebDuUBS0fGz-LAk' --folder --output data

'gdown' 不是内部或外部命令，也不是可运行的程序
或批处理文件。


# Data

In [7]:
class CSCDataset(Dataset):

    def __init__(self):
        super(CSCDataset, self).__init__()
        with open("data/trainall.times2.pkl", mode='br') as f:
            train_data = pickle.load(f)

        self.train_data = train_data

    def __getitem__(self, index):
        src = self.train_data[index]['src']
        tgt = self.train_data[index]['tgt']
        return src, tgt

    def __len__(self):
        if test_mode:
            return data_length
        return len(self.train_data)

In [8]:
train_data = CSCDataset()

In [9]:
train_data.__getitem__(0)

('纽约早盘作为基准的低硫轻油，五月份交割价攀升一点三四美元，来到每桶二十八点二五美元，而上周五曾下挫一豪元以上。',
 '纽约早盘作为基准的低硫轻油，五月份交割价攀升一点三四美元，来到每桶二十八点二五美元，而上周五曾下挫一美元以上。')

## Dataloader

In [10]:
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

In [11]:
def collate_fn(batch):
    src, tgt = zip(*batch)
    src, tgt = list(src), list(tgt)

    src = tokenizer(src, padding='max_length', max_length=max_length, return_tensors='pt', truncation=True)
    tgt = tokenizer(tgt, padding='max_length', max_length=max_length, return_tensors='pt', truncation=True)

    return src, (src['input_ids'] != tgt['input_ids']).float()

In [12]:
train_loader = DataLoader(train_data, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)

In [13]:
inputs, targets = next(iter(train_loader))

# Model

In [14]:
class CSCModel(nn.Module):

    def __init__(self):
        super(CSCModel, self).__init__()

        self.semantic_encoder = AutoModel.from_pretrained("hfl/chinese-roberta-wwm-ext")
        self.word_embeddings = self.semantic_encoder.get_input_embeddings()

        transformer_encoder_layer = nn.TransformerEncoderLayer(d_model=768 * 3, nhead=6, dim_feedforward=1024,
                                                      activation='gelu',
                                                      batch_first=True)
        self.transformer = nn.TransformerEncoder(transformer_encoder_layer, num_layers=1)

        self.output_layer = nn.Sequential(
            nn.Linear(768 * 3, 1),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        token_num = inputs['input_ids'].size(1)
        outputs = self.semantic_encoder(**inputs)
        word_embeddings = self.word_embeddings(inputs['input_ids'])
        cls_outputs = outputs.last_hidden_state[:, 0:1, :].repeat(1, token_num, 1)
        outputs = torch.concat([outputs.last_hidden_state, word_embeddings, cls_outputs], dim=2)
        outputs = self.transformer(outputs)
        return self.output_layer(outputs).squeeze(2) * inputs['attention_mask']

In [15]:
test_model = CSCModel()
test_model(inputs).size()

Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([32, 128])

# Train

In [16]:
model = CSCModel()
criteria = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
start_epoch = 0
total_step = 0
record = []  # 记录loss、accuracy变化等

Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [17]:
if os.path.exists(model_path):
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch']
    total_step = checkpoint['total_step']
    record = checkpoint['record']
    print("恢复训练，epoch:", start_epoch)

In [18]:
model = model.to(device)
model = model.train()

In [19]:
writer = SummaryWriter(log_dir='runs/csc_model')
# 恢复之前的数据
for item in record:
    step, loss, recall, precision = item
    writer.add_scalar(tag="record/loss", scalar_value=loss, global_step=step)
    writer.add_scalar(tag="record/recall", scalar_value=recall, global_step=step)
    writer.add_scalar(tag="record/precision", scalar_value=precision, global_step=step)

In [20]:
if not test_mode:
    %load_ext tensorboard
    %tensorboard --logdir=runs
    time.sleep(10)

Launching TensorBoard...

In [21]:
total_loss = 0.
total_correct = 0
total_num = 0
total_correct_wrong_char = 0
total_wrong_char = 0
total_precision_num = 0
total_precision_correct = 0

for epoch in range(start_epoch, epochs):

    step = 0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criteria(outputs, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        step += 1
        total_step += 1

        total_loss += loss.detach().item()
        total_correct += ((outputs >= 0.5).int() == targets.int()).sum().item()
        total_num += len(targets.flatten())

        predicts = outputs >= 0.5
        # 统计成功预测错字的数量
        total_correct_wrong_char += predicts[targets == 1].sum().item()
        # 统计错字的数量
        total_wrong_char += (targets == 1).sum().item()

        total_precision_num += predicts.sum().item()
        total_precision_correct += (targets[predicts == 1]).sum().item()

        if total_step % log_after_step == 0:
            loss = total_loss / log_after_step
            recall = total_correct_wrong_char / (total_wrong_char + 1e-9)
            precision = total_precision_correct / (total_precision_num + 1e-9)

            print("Epoch {}, "
                  "Step {}/{}, "
                  "Total Step {}, "
                  "loss {:.5f}, "
                  "recall {:.4f}, "
                  "precision {:.4f}".format(epoch, step, len(train_loader), total_step,
                                            loss,
                                            recall,
                                            precision))
            writer.add_scalar(tag="record/loss", scalar_value=loss, global_step=total_step)
            writer.add_scalar(tag="record/recall", scalar_value=recall, global_step=total_step)
            writer.add_scalar(tag="record/precision", scalar_value=precision, global_step=total_step)

            record.append((total_step, loss, recall, precision,))

            total_loss = 0.
            total_correct = 0
            total_num = 0
            total_correct_wrong_char = 0
            total_wrong_char = 0
            total_precision_num = 0
            total_precision_correct = 0

    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch + 1,
        'total_step': total_step,
        'record': record
    }, model_path)


Epoch 0, Step 20/8882, Total Step 20, loss 0.08239, accuracy 0.9781, recall 0.0259, precision 0.0261
Epoch 0, Step 40/8882, Total Step 40, loss 0.02361, accuracy 0.9920, recall 0.4467, precision 0.7057
Epoch 0, Step 60/8882, Total Step 60, loss 0.01509, accuracy 0.9954, recall 0.7114, precision 0.8413
Epoch 0, Step 80/8882, Total Step 80, loss 0.01395, accuracy 0.9954, recall 0.7187, precision 0.8396
Epoch 0, Step 100/8882, Total Step 100, loss 0.01261, accuracy 0.9958, recall 0.7441, precision 0.8667
Epoch 0, Step 120/8882, Total Step 120, loss 0.01139, accuracy 0.9960, recall 0.7582, precision 0.8667


KeyboardInterrupt: 

# Inference

In [25]:
if os.path.exists(model_path):
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['model'])
    model = model.to(device)

In [27]:
model = model.eval()

In [30]:
def predict(text):
    inputs = tokenizer(text, return_tensors='pt')
    outputs = model(inputs.to(device))

    segments = []
    last_index = 0
    for index in torch.where(outputs[0, 1:-1] >= 0.5)[0].tolist():
        segments.append(text[last_index:index])
        segments.append("\033[1;31m" + text[index] + "\033[0m")
        last_index = index + 1
    segments.append(text[last_index:])

    return ''.join(segments), (outputs[0, 1:-1] >= 0.5).int()

In [32]:
text, output = predict("我昨天吃了一个火聋果")
print(text)
print(output)

怎么越来越[1;31m才[0m了
tensor([0, 0, 0, 0, 0, 1, 0], device='cuda:0', dtype=torch.int32)


# Evaluation

In [33]:
with open("data/test.sighan15.pkl", mode='br') as f:
    test_data = pickle.load(f)

In [34]:
total_num = 0
total_correct = 0

total_recall_num = 0
total_recall_correct = 0

total_precision_num = 0
total_precision_correct = 0

prograss = tqdm(range(len(test_data)))
for i in prograss:
    src, tgt = test_data[i]['src'], test_data[i]['tgt']

    _, output = predict(src)
    target = (torch.tensor(test_data[i]['src_idx']) != torch.tensor(test_data[i]['tgt_idx'])).int()[1:-1].to(device)

    if len(output) != len(target):
        print("第%d条数据异常" % i)

    total_num += 1
    if (output != target).sum().item() == 0:
        total_correct += 1

    total_recall_correct += output[target == 1].sum().item()
    total_recall_num += (target == 1).sum().item()

    total_precision_num += (output == 1).sum().item()
    total_precision_correct += (target[output == 1]).sum().item()

    recall = total_recall_correct / (total_recall_num + 1e-9)
    precision = total_precision_correct / (total_precision_num + 1e-9)

    prograss.set_postfix({
        'recall': recall,
        'precision': precision
    })

100%|██████████| 1100/1100 [00:09<00:00, 112.68it/s, accuracy=0.695, recall=0.656, precision=0.725]
