# Environment

In [1]:
!pip install transformers == 4.21.2



You should consider upgrading via the 'D:\Anaconda3\python.exe -m pip install --upgrade pip' command.





In [2]:
import transformers
import pickle

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel

In [3]:
transformers.__version__

'4.21.2'

# Global Config

In [4]:
max_length = 128
batch_size = 32

log_after_step = 20

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
!gdown '1dC09i57lobL91lEbpebDuUBS0fGz-LAk' --folder --output data

'gdown' 不是内部或外部命令，也不是可运行的程序
或批处理文件。


# Data

In [55]:
class CSCDataset(Dataset):

    def __init__(self):
        super(CSCDataset, self).__init__()
        with open("data/trainall.times2.pkl", mode='br') as f:
            train_data = pickle.load(f)

        self.train_data = train_data

    def __getitem__(self, index):
        src = self.train_data[index]['src']
        tgt = self.train_data[index]['tgt']
        return src, tgt

    def __len__(self):
        return len(self.train_data)

In [56]:
train_data = CSCDataset()

In [57]:
train_data.__getitem__(0)

('纽约早盘作为基准的低硫轻油，五月份交割价攀升一点三四美元，来到每桶二十八点二五美元，而上周五曾下挫一豪元以上。',
 '纽约早盘作为基准的低硫轻油，五月份交割价攀升一点三四美元，来到每桶二十八点二五美元，而上周五曾下挫一美元以上。')

## Dataloader

In [58]:
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

In [59]:
def collate_fn(batch):
    src, tgt = zip(*batch)
    src, tgt = list(src), list(tgt)

    src = tokenizer(src, padding='max_length', max_length=max_length, return_tensors='pt', truncation=True)
    tgt = tokenizer(tgt, padding='max_length', max_length=max_length, return_tensors='pt', truncation=True)

    return src, (src['input_ids'] != tgt['input_ids']).float()

In [60]:
dataloader = DataLoader(train_data, batch_size=batch_size, collate_fn=collate_fn)

In [61]:
inputs, targets = next(iter(dataloader))

In [62]:
targets

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0

# Model

In [63]:
class CSCModel(nn.Module):

    def __init__(self):
        super(CSCModel, self).__init__()

        self.semantic_encoder = AutoModel.from_pretrained("hfl/chinese-roberta-wwm-ext")

        self.output_layer = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        outputs = self.semantic_encoder(**inputs)
        return self.output_layer(outputs.last_hidden_state).squeeze(2)

# Train

In [64]:
model = CSCModel()
model = model.to(device)
model = model.train()

Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [65]:
criteria = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [66]:
step = 0
total_loss = 0.
total_correct = 0
total_num = 0
total_correct_wrong_char = 0
total_wrong_char = 0

for epoch in range(10):

    for inputs_, targets_ in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criteria(outputs, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        step += 1

        total_loss += loss.detach().item()
        total_correct += ((outputs >= 0.5).int() == targets.int()).sum().item()
        total_num += len(targets.flatten())

        # 统计成功预测错字的数量
        total_correct_wrong_char += (outputs >= 0.5)[targets == 1].sum().item()
        # 统计错字的数量
        total_wrong_char += (targets == 1).sum().item()

        if step % log_after_step == 0:
            print("Epoch {}, Step {}, loss {:.5f}, accuracy {:.4f}, recall {:.4f}".format(epoch, step,
                                                                                          total_loss / log_after_step,
                                                                                          total_correct / total_num,
                                                                                          total_correct_wrong_char / total_wrong_char))
            total_loss = 0.
            total_correct = 0
            total_num = 0
            total_correct_wrong_char = 0
            total_wrong_char = 0


Epoch 0, Step 20, loss 0.09453, accuracy 0.9856, recall 0.1500
Epoch 0, Step 40, loss 0.00275, accuracy 1.0000, recall 1.0000


KeyboardInterrupt: 

In [18]:
text = "纽约早盘作为基准的低硫轻油，五月份交割价攀升一点三四美元，来到每桶二十八点二五美元，而上周五曾下挫一豪元以上。"

In [19]:
inputs = tokenizer(text, return_tensors='pt')

In [20]:
outputs = model(inputs.to(device))

In [21]:
for index in torch.where(outputs[0, 1:-1] >= 0.5)[0].tolist():
    text = text[:index] + "\033[1;31m" + text[index] + "\033[0m" + text[index + 1:]

In [22]:
print(text)

纽约早盘作为基准的低硫轻油，五月份交割价攀升一点三四美元，来到每桶二十八点二五美元，而上周五曾下挫一豪元以上。
