In [1]:
# !pip install jieba sentencepiece transformers

In [2]:
import os
import pickle
import random

import numpy as np
from tqdm import tqdm
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import T5ForConditionalGeneration
from tokenization_enc_dec import EncDecTokenizer

In [3]:
class Dataset(torch.utils.data.IterableDataset):
    def __init__(self, batch_size = 16, max_len=50):
        super(Dataset).__init__()
        self.batch_size = batch_size
        
        data = []
        data_root = '../chatdata/'
        for f in os.listdir(data_root):
            data += pickle.load(open(os.path.join(data_root, f), 'rb'))
        self.data = data

    def random_data(self):
        one_data = random.choice(self.data)
        lens = random.randint(2, 5)
        ind = random.randint(0, len(one_data) - lens - 1)
        sentences = one_data[ind:ind+lens]
        maxlen = max([len(x) for x in sentences])
        if maxlen > 40:
            return self.random_data()
        s = ''.join(sentences)
        if '拜拜' in s or '再见' in s:
            if random.random() < 0.95:
                return self.random_data()
        return sentences

    def get_single_data(self, sentences):
        input_ids = []
        for s in sentences[:-1]:
            input_ids += tokenizer.encode(s) + [tokenizer.sep_id]
        input_ids += [tokenizer.get_sentinel_id(0)]
        y_ids = tokenizer.encode(sentences[-1])
        decoder_input_ids = [tokenizer.get_sentinel_id(0)] + y_ids
        labels = y_ids + [tokenizer.sep_id]
        return torch.LongTensor(input_ids), torch.LongTensor(decoder_input_ids), torch.LongTensor(labels)
        
    def __iter__(self):
        batch_size = self.batch_size
        batch = []
        keep = np.random.randint(2, 6)
        while True:
            ids, dids, lbl = self.get_single_data(self.random_data())
            batch.append((ids, dids, lbl))
            if len(batch) >= batch_size:
                input_ids = pad_sequence([
                    x[0]
                    for x in batch
                ], batch_first=True, padding_value=tokenizer.pad_id)
                mask = (input_ids != tokenizer.pad_id).to(input_ids.dtype)
                decoder_input_ids = pad_sequence([
                    x[1]
                    for x in batch
                ], batch_first=True, padding_value=tokenizer.pad_id)
                decoder_mask = (decoder_input_ids != tokenizer.pad_id).to(input_ids.dtype)
                # padding -100是源代码里面的magic number， 参考：
                # https://github.com/huggingface/transformers/blob/1c06240e1b3477728129bb58e7b6c7734bb5074e/src/transformers/models/t5/modeling_t5.py#L1580
                labels = pad_sequence([
                    x[2]
                    for x in batch
                ], batch_first=True, padding_value=-100)
                yield input_ids, mask, decoder_input_ids, decoder_mask, labels
                batch = []
                keep = np.random.randint(2, 6)

In [4]:
%%time
model = T5ForConditionalGeneration.from_pretrained('./torch_eva/')
tokenizer = EncDecTokenizer('./EVA/src/bpe_dialog_new/vocab.txt')

CPU times: user 1min 18s, sys: 8.34 s, total: 1min 27s
Wall time: 27.5 s


In [5]:
%%time
ds = Dataset(batch_size=16)
dl = torch.utils.data.DataLoader(ds, num_workers=8, batch_size=None, pin_memory=True, prefetch_factor=10)

CPU times: user 1.62 s, sys: 887 ms, total: 2.5 s
Wall time: 2.52 s


In [10]:
%%time
fp16 = True
cuda = True

if fp16:
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=1e-3,
        momentum=0.9
    )
else:
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=1e-3,
        momentum=0.9
    )
if cuda:
    if fp16:
        model = model.half().cuda()
    else:
        model = model.cuda()

CPU times: user 1min 32s, sys: 41.9 s, total: 2min 14s
Wall time: 5.55 s


In [11]:
step = 0

In [12]:
from torch.nn import CrossEntropyLoss

In [13]:
optimizer.zero_grad()
losses = []
pbar = tqdm(dl)
loss_fct2 = CrossEntropyLoss(reduction='none')

for x, m0, y, m1, z in pbar:
    if cuda:
        x = x.cuda()
        y = y.cuda()
        z = z.cuda()
        m0 = m0.cuda()
        m1 = m1.cuda()
    with torch.cuda.amp.autocast():
        out = model(
            input_ids=x,
            attention_mask=m0,
            decoder_input_ids=y,
            decoder_attention_mask=m1,
            # labels=z
        )
        # loss = out.loss
        # loss = loss_fct(out.logits.view(-1, out.logits.size(-1)), z.view(-1))
        loss2 = loss_fct2(out.logits.view(-1, out.logits.size(-1)), z.view(-1))
        loss = torch.sum(loss2 * (z.view(-1) >= 0)) / torch.sum(z.view(-1) >= 0)

    loss.backward()
    if torch.isnan(loss) or torch.isinf(loss):
        print('bad loss')
        optimizer.zero_grad()
        continue
    else:
        optimizer.step()
        optimizer.zero_grad()
    loss = loss.detach().cpu().numpy()
    losses.append(loss)
    losses = losses[-100:]
    pbar.set_description(f'step: {step} loss: {np.mean(losses):.4f}')
    step += 1
    if step > 0 and step % (60 * 60) == 0:
        print('save', step, loss)
        torch.save(model.state_dict(), f'model_{step}.pt')
        torch.save(optimizer.state_dict(), f'opt_{step}.pt')

0it [00:00, ?it/s]Building prefix dict from the default dictionary ...
Building prefix dict from the default dictionary ...
Building prefix dict from the default dictionary ...Building prefix dict from the default dictionary ...Building prefix dict from the default dictionary ...
Building prefix dict from the default dictionary ...Building prefix dict from the default dictionary ...
Building prefix dict from the default dictionary ...Loading model from cache /tmp/jieba.cache


Loading model from cache /tmp/jieba.cache

Loading model from cache /tmp/jieba.cacheLoading model from cache /tmp/jieba.cacheLoading model from cache /tmp/jieba.cacheLoading model from cache /tmp/jieba.cache
Loading model from cache /tmp/jieba.cacheLoading model from cache /tmp/jieba.cache





Loading model cost 0.836 seconds.Loading model cost 0.835 seconds.

Prefix dict has been built successfully.Prefix dict has been built successfully.Loading model cost 0.838 seconds.


Prefix dict has been built successfull

save 3600 1.8041242


step: 7199 loss: 1.5582: : 7199it [50:23,  2.30it/s]

save 7200 1.6094422


step: 10799 loss: 1.4074: : 10799it [1:15:44,  2.29it/s]

save 10800 1.3296105


step: 14399 loss: 1.3315: : 14399it [1:40:58,  2.44it/s]

save 14400 1.0787582


step: 17999 loss: 1.2558: : 17999it [2:06:21,  2.37it/s]

save 18000 1.3367486


step: 21599 loss: 1.2408: : 21599it [2:31:45,  2.32it/s]

save 21600 1.1896592


step: 25199 loss: 1.2122: : 25199it [2:57:02,  2.41it/s]

save 25200 1.5990603


step: 28799 loss: 1.2051: : 28799it [3:22:24,  2.33it/s]

save 28800 1.0319228


step: 32399 loss: 1.1750: : 32399it [3:47:44,  2.40it/s]

save 32400 1.0566983


step: 35999 loss: 1.1745: : 35999it [4:13:10,  2.33it/s]

save 36000 1.2732273


step: 39599 loss: 1.1713: : 39599it [4:38:40,  2.40it/s]

save 39600 1.0271161


step: 43199 loss: 1.1535: : 43199it [5:04:06,  2.46it/s]

save 43200 1.1684688


step: 46799 loss: 1.1566: : 46799it [5:29:34,  2.44it/s]

save 46800 1.0300099


step: 50399 loss: 1.1335: : 50399it [5:54:59,  2.44it/s]

save 50400 1.1262554


step: 53999 loss: 1.1372: : 53999it [6:20:24,  2.49it/s]

save 54000 0.97241783


step: 57599 loss: 1.1314: : 57599it [6:45:49,  2.40it/s]

save 57600 1.044987


step: 61199 loss: 1.1301: : 61199it [7:11:12,  2.46it/s]

save 61200 1.1718122


step: 64799 loss: 1.1209: : 64799it [7:36:39,  2.38it/s]

save 64800 1.1101009


step: 68399 loss: 1.1219: : 68399it [8:02:02,  2.35it/s]

save 68400 1.236889


step: 71999 loss: 1.0911: : 71999it [8:27:25,  2.33it/s]

save 72000 0.91616446


step: 75599 loss: 1.1008: : 75599it [8:52:55,  2.44it/s]

save 75600 1.0377946


step: 79199 loss: 1.1044: : 79199it [9:18:20,  2.42it/s]

save 79200 1.0226601


step: 82799 loss: 1.0990: : 82799it [9:43:49,  2.39it/s]

save 82800 1.1966805


step: 86399 loss: 1.0968: : 86399it [10:09:20,  2.48it/s]

save 86400 1.1003646


step: 89999 loss: 1.0721: : 89999it [10:34:48,  2.24it/s]

save 90000 1.1350629


step: 93599 loss: 1.1011: : 93599it [11:00:19,  2.46it/s]

save 93600 1.0899584


step: 97199 loss: 1.0903: : 97199it [11:25:49,  2.38it/s]

save 97200 1.1363858


step: 100799 loss: 1.0771: : 100799it [11:51:12,  2.31it/s]

save 100800 1.0004612


step: 104399 loss: 1.0786: : 104399it [12:16:33,  2.38it/s]

save 104400 1.0684063


step: 107999 loss: 1.0796: : 107999it [12:42:07,  2.37it/s]

save 108000 1.0564996


step: 111599 loss: 1.0763: : 111599it [13:07:33,  2.25it/s]

save 111600 0.9773707


step: 115199 loss: 1.0806: : 115199it [13:33:01,  2.39it/s]

save 115200 1.0635233


step: 118799 loss: 1.0520: : 118799it [13:58:28,  2.34it/s]

save 118800 0.93359685


step: 122399 loss: 1.0879: : 122399it [14:23:57,  2.37it/s]

save 122400 1.141929


step: 125999 loss: 1.0746: : 125999it [14:49:28,  2.36it/s]

save 126000 1.0127586


step: 129599 loss: 1.0557: : 129599it [15:14:53,  2.38it/s]

save 129600 1.1555089


step: 133199 loss: 1.0515: : 133199it [15:40:21,  2.37it/s]

save 133200 0.9401102


step: 136799 loss: 1.0636: : 136799it [16:05:49,  2.30it/s]

save 136800 1.0743293


step: 140399 loss: 1.0585: : 140399it [16:31:20,  2.44it/s]

save 140400 1.0541804


step: 143999 loss: 1.0454: : 143999it [16:56:49,  2.40it/s]

save 144000 0.95395374


step: 147599 loss: 1.0531: : 147599it [17:22:16,  2.43it/s]

save 147600 0.9652821


step: 151199 loss: 1.0582: : 151199it [17:47:42,  2.31it/s]

save 151200 0.9304118


step: 154799 loss: 1.0622: : 154799it [18:13:02,  2.23it/s]

save 154800 1.2611653


step: 158399 loss: 1.0580: : 158399it [18:38:31,  2.45it/s]

save 158400 1.0977086


step: 161999 loss: 1.0358: : 161999it [19:03:54,  2.39it/s]

save 162000 0.99900043


step: 165599 loss: 1.0342: : 165599it [19:29:13,  2.44it/s]

save 165600 0.9897446


step: 169199 loss: 1.0346: : 169199it [19:54:42,  2.37it/s]

save 169200 0.97383714


step: 172799 loss: 1.0518: : 172799it [20:20:06,  2.44it/s]

save 172800 1.0148427


step: 176399 loss: 1.0192: : 176399it [20:45:34,  2.40it/s]

save 176400 1.0250503


step: 176750 loss: 1.0093: : 176751it [20:48:24,  2.36it/s]


KeyboardInterrupt: 