# Environment

In [158]:
try:
    import transformers
except:
    !pip install transformers

In [159]:
import transformers
import pickle

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel

In [160]:
# Show which transformers release this kernel imported.
transformers.__version__

'4.21.2'

# Global Config

In [242]:
# Length that collate_fn pads/truncates every tokenized sentence to.
max_length = 128
# Number of sentences per DataLoader batch.
batch_size = 32

# The training loop prints the loss every `log_after_step` optimizer steps.
log_after_step = 20

# Use a CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data

In [162]:
# Fetch the dataset folder with the gdown CLI and store it under ./data.
# NOTE(review): the recorded output of this cell says the `gdown` command was not
# found, so the files were presumably obtained another way — confirm gdown is
# installed (or the data is already in ./data) before relying on this cell.
!gdown '1dC09i57lobL91lEbpebDuUBS0fGz-LAk' --folder --output data

'gdown' 不是内部或外部命令，也不是可运行的程序
或批处理文件。


In [163]:
# Tokenizer for the same checkpoint that CopyModel loads as its encoder.
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

In [164]:
# Sample sentence. No later cell reads this value before the Inference section
# assigns a new string to `sentence`.
sentence = "昨天下雨了，天气非常凉爽，今天我们出门去玩吧，就去昨天那个地方。"

In [192]:
class CSCDataset(Dataset):
    """Dataset of target sentences read from a pickled sequence of records.

    Every record must be indexable by the key ``'tgt'``; only that field is
    exposed, so one item is the ``'tgt'`` value of one record.

    Args:
        path: Pickle file to load. Defaults to the training file this notebook
            expects under ``data/``.
    """

    def __init__(self, path="data/trainall.times2.pkl"):
        super().__init__()
        # NOTE: pickle.load can execute arbitrary code embedded in the file;
        # only point this at data from a trusted source.
        with open(path, mode='rb') as f:
            self.train_data = pickle.load(f)

    def __getitem__(self, index):
        """Return the ``'tgt'`` field of the record at ``index``."""
        return self.train_data[index]['tgt']

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.train_data)

In [193]:
# Instantiate the dataset; its constructor unpickles the whole training file into memory.
train_set = CSCDataset()

In [194]:
# Peek at one sample to check what an item looks like.
train_set[10]

'据科技谘询公司说，公司的执行长班哈姆对掌上型电脑市场表现感到满意，今年第一季它的市场表现较去年下降百分之十二。'

# DataLoader

In [195]:
def collate_fn(batch):
    """Tokenize a batch of sentences and reuse the token ids as the labels.

    Args:
        batch: Iterable of sentences as produced by the dataset.

    Returns:
        ``(inputs, targets)``: the tokenizer output padded/truncated to
        ``max_length``, and its ``input_ids`` tensor (the same object, not a copy).
    """
    sentences = [sample for sample in batch]
    encoded = tokenizer(
        sentences,
        truncation=True,
        padding='max_length',
        max_length=max_length,
        return_tensors='pt',
    )
    return encoded, encoded['input_ids']

In [196]:
# Batches are built by collate_fn; `shuffle` is not set, so the DataLoader default applies.
# NOTE(review): training loaders are commonly created with shuffle=True — confirm
# that the unshuffled order is intended.
train_loader = DataLoader(train_set, batch_size=batch_size, collate_fn=collate_fn)

In [197]:
# Pull the first batch; the Model section feeds `inputs` through the network as a shape check.
inputs, targets = next(iter(train_loader))

# Model

In [198]:
class CopyModel(nn.Module):
    """Pretrained encoder, bottleneck MLP and Transformer decoder predicting one token per position.

    Args:
        max_length: Fixed sequence length of every input. It sizes the
            bottleneck MLP, so inputs must be padded/truncated to exactly
            this many tokens.
    """

    def __init__(self, max_length=128):
        super(CopyModel, self).__init__()

        self.max_length = max_length

        self.bert = AutoModel.from_pretrained("hfl/chinese-roberta-wwm-ext")
        # Take the dimensions from the checkpoint config rather than hard-coding them,
        # so the layers below always agree with the encoder output.
        self.hidden_size = self.bert.config.hidden_size
        num_heads = self.bert.config.num_attention_heads

        # The whole sequence is flattened into one vector, squeezed through a
        # 2048-unit layer and expanded back to (max_length, hidden_size).
        flat_size = max_length * self.hidden_size
        self.interlayer = nn.Sequential(
            nn.Linear(flat_size, 2048),
            nn.ReLU(),
            nn.Linear(2048, flat_size),
            nn.ReLU(),
        )

        decoder_layer = nn.TransformerDecoderLayer(d_model=self.hidden_size, nhead=num_heads, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=2)

        # Outputs raw scores with one entry per tokenizer id (no softmax is applied here).
        self.predictor = nn.Sequential(
            nn.Linear(self.hidden_size, 1024),
            nn.ReLU(),
            nn.Linear(1024, len(tokenizer)),
        )

    def forward(self, inputs):
        """Run the model on a tokenized batch.

        Args:
            inputs: Mapping of keyword arguments for the encoder (the tokenizer
                output), with sequences of exactly ``max_length`` tokens.

        Returns:
            Tensor of scores shaped ``(batch, max_length, len(tokenizer))``.

        Raises:
            ValueError: If the sequence length differs from ``max_length``.
        """
        hidden = self.bert(**inputs)['last_hidden_state']

        batch, seq_len, _ = hidden.shape
        if seq_len != self.max_length:
            # Flattening with an inferred batch dimension would otherwise either
            # fail obscurely or silently mix tokens of different samples.
            raise ValueError(
                "expected sequences of length {}, got {}".format(self.max_length, seq_len)
            )

        bottleneck = self.interlayer(hidden.reshape(batch, -1))
        bottleneck = bottleneck.view(batch, self.max_length, self.hidden_size)

        # The bottleneck output serves as both the decoder target and its memory.
        decoded = self.decoder(bottleneck, bottleneck)
        return self.predictor(decoded)

In [199]:
# Build the model with the same sequence length the tokenizer pads to, instead of
# relying on the constructor default happening to match the global config.
model = CopyModel(max_length=max_length).to(device)

Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [200]:
# Shape check on one batch. The batch must be moved to the model's device first
# (otherwise this fails whenever CUDA is in use); no gradients are needed here.
with torch.no_grad():
    outputs = model(inputs.to(device))
outputs.size()

torch.Size([32, 128, 21128])

# Training

In [243]:
# Padding positions carry no signal: ignore them in the loss. The id is taken
# from the tokenizer instead of being hard-coded.
criteria = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# Assigning the result keeps the cell from printing the whole module repr.
model = model.train()

In [245]:
# Loss accumulated since the last log line, and the global optimizer step count.
total_loss = 0.
step = 0

for epoch in range(10):
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        # Flatten to (batch * seq_len, vocab) scores against (batch * seq_len,) labels.
        loss = criteria(outputs.view(-1, outputs.size(-1)), targets.view(-1))

        # Clear gradients before backward so that leftovers from an interrupted
        # run of this cell cannot leak into the first update of the next run.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        step += 1

        if step % log_after_step == 0:
            print("step {}, loss {:.4f}".format(step, total_loss / log_after_step))
            # Start a fresh window; without the reset the printed "average"
            # keeps growing with every logging interval.
            total_loss = 0.

step 1, loss 1.6726


KeyboardInterrupt: 

# Inference

In [247]:
# Sentence used for the inference demo below.
# NOTE(review): the decoded output recorded at the end of the notebook differs from
# this text in a few characters, so the input presumably contains deliberate
# errors for the model to correct — confirm.
sentence = "纽约早盘作为基准的低硫轻油，五越份交割价攀升一点三四人民币，来到每桶二十八点二五美元，而上周五曾下挫一美元以上。"

In [248]:
# Tokenize exactly as during training: pad/truncate to the global `max_length`
# rather than a second hard-coded copy of the value.
inputs = tokenizer(sentence, return_tensors='pt', max_length=max_length, padding='max_length',
                   add_special_tokens=True, truncation=True)

In [249]:
# Predict with dropout disabled and without building an autograd graph, then
# restore whatever mode the model was in so the training cell can be rerun as-is.
was_training = model.training
model.eval()
with torch.no_grad():
    outputs = model(inputs.to(device))
model.train(was_training)

In [250]:
"".join(tokenizer.convert_ids_to_tokens(outputs.argmax(-1)[0]))

'[CLS]纽约早盘作为基准的低硫轻油，五月份交割价攀升一点三四美元，来到每桶二十八点二五美元，而上周五曾下挫一美元以上。[SEP]南关说元国为议日以过会会有十的上会这十元年行国行表国，，会二总表国国会，不成十南国人的年表人人二说特数会表的年日日国行来美二会国哈的会员不十的'