In [1]:
import math
import re
from collections import OrderedDict

import numpy as np
import torch
from torch.nn import functional as F
from tqdm import tqdm
from transformers import T5Config
from transformers import T5EncoderModel

In [2]:
from tokenization_enc_dec import EncDecTokenizer

VOCAB_PATH = './vocab.txt'  # vocabulary file consumed by the project tokenizer
tokenizer = EncDecTokenizer(VOCAB_PATH)

In [16]:
class EncoderModel(torch.nn.Module):
    """T5 encoder followed by a linear projection to a unit-norm sentence vector.

    The encoder's hidden state at sequence position 0 is projected to
    ``out_dim`` features and L2-normalised along the last axis.

    Args:
        config: ``T5Config`` used to build the encoder; ``config.d_model`` is
            the projection's input width.
        out_dim: size of the returned embedding (default 512).
    """

    def __init__(self, config, out_dim=512):
        super(EncoderModel, self).__init__()
        self.encoder = T5EncoderModel(config)
        # Match the encoder width instead of hard-coding it.
        self.out = torch.nn.Linear(config.d_model, out_dim)

    def forward(self, input_ids, mask):
        """Encode ``input_ids`` (with attention ``mask``) into ``(batch, out_dim)`` unit vectors."""
        hidden = self.encoder(input_ids=input_ids, attention_mask=mask).last_hidden_state
        # Pool by taking position 0 (the training Dataset prepends token id 1 there).
        pooled = hidden[:, 0, :]
        return F.normalize(self.out(pooled), p=2, dim=-1)

class TrainModel(torch.nn.Module):
    """Siamese training wrapper around a sentence encoder.

    Both sentences of a pair go through the same ``encoder``. The module returns
    2-way logits computed from the pair features ``[a, b, |a - b|, a + b]`` and
    an angular similarity ``1 - angle(a, b) / pi`` in ``[0, 1]``.

    Args:
        encoder: module mapping ``(input_ids, mask)`` to a ``(batch, embed_dim)``
            tensor; assumed L2-normalised (``EncoderModel`` normalises its output).
        embed_dim: width of the encoder output (default 512).
        clip: bound applied to the cosine before ``arccos`` (default 0.99).
    """

    def __init__(self, encoder, embed_dim=512, clip=0.99):
        super(TrainModel, self).__init__()
        self.encoder = encoder
        self.clip = clip
        # The classifier sees the four concatenated pair features.
        self.proj = torch.nn.Linear(embed_dim * 4, 2)

    def forward(self, input_ids_0, mask_0, input_ids_1, mask_1):
        """Return ``(logits, sim)``: logits of shape ``(batch, 2)`` and sim of shape ``(batch,)``."""
        a = self.encoder(input_ids_0, mask_0)
        b = self.encoder(input_ids_1, mask_1)

        pair_features = torch.cat([a, b, torch.abs(a - b), a + b], dim=-1)
        logits = self.proj(pair_features)

        # For unit vectors the dot product is the cosine of the angle between them.
        # Clipping to exactly -1.0 / 1.0 can produce NaN under fp16, so keep a margin.
        cosine = torch.clip(torch.sum(a * b, dim=-1), -self.clip, self.clip)
        sim = 1 - torch.arccos(cosine) / math.pi

        return logits, sim

In [4]:
import os
import random
from tqdm import tqdm
import torch
from tokenization_enc_dec import EncDecTokenizer

# Re-created so this cell, and the Dataset class below that reads the
# module-level `tokenizer`, does not depend on an earlier cell.
VOCAB_PATH = './vocab.txt'
tokenizer = EncDecTokenizer(VOCAB_PATH)


class Dataset(torch.utils.data.IterableDataset):
    """Infinite stream of padded sentence-pair batches.

    Every file in ``data_root`` is read as tab-separated lines. A line is kept
    when it has at least three fields, its last field is ``'0'`` or ``'1'`` and
    the two fields before it are non-blank; it becomes ``(sentence_0,
    sentence_1, label)``.

    Iteration never ends: the data is reshuffled on every pass, and each batch
    is ``(x0, m0, x1, m1, y)``. ``x0``/``x1`` are right-padded LongTensors whose
    rows start with token id 1, ``m0``/``m1`` are int64 masks (1 where the id is
    not ``tokenizer.pad_id``), and ``y`` holds the integer labels.
    Relies on the module-level ``tokenizer``.

    Raises:
        ValueError: on iteration, if no valid line was found.
    """

    def __init__(self, batch_size=16, data_root = '../sts_dataset'):
        self.batch_size = batch_size
        lines = []
        paths = sorted(os.path.join(data_root, name) for name in os.listdir(data_root))
        for path in tqdm(paths):
            # Close each file promptly and decode explicitly rather than by locale.
            with open(path, encoding='utf-8') as fh:
                lines += fh.read().split('\n')
        data = []
        for line in tqdm(lines):
            fields = line.split('\t')
            if (
                len(fields) >= 3
                and fields[-1] in ('0', '1')
                and fields[-3].strip()
                and fields[-2].strip()
            ):
                data.append((fields[-3], fields[-2], int(fields[-1])))
        random.shuffle(data)
        self.data = data
        print(len(data))

    @staticmethod
    def _pad(texts):
        """Tokenise ``texts`` with a leading id 1 and right-pad into one LongTensor."""
        return torch.nn.utils.rnn.pad_sequence(
            [torch.LongTensor([1] + tokenizer.encode(text)) for text in texts],
            batch_first=True,
            padding_value=tokenizer.pad_id,
        )

    def _collate(self, batch):
        """Turn a list of ``(s0, s1, label)`` triples into ``(x0, m0, x1, m1, y)``."""
        x0 = self._pad([item[0] for item in batch])
        x1 = self._pad([item[1] for item in batch])
        y = torch.LongTensor([item[2] for item in batch])
        m0 = (x0 != tokenizer.pad_id).to(torch.int64)
        m1 = (x1 != tokenizer.pad_id).to(torch.int64)
        return x0, m0, x1, m1, y

    def __iter__(self):
        if not self.data:
            # Otherwise the `while True` below would spin forever without yielding.
            raise ValueError('Dataset is empty: no valid lines were found')
        batch = []
        while True:
            # Reshuffle on every pass so successive passes use different orders.
            random.shuffle(self.data)
            for item in self.data:
                batch.append(item)
                if len(batch) >= self.batch_size:
                    yield self._collate(batch)
                    batch = []

In [5]:
# Architecture of the truncated encoder. Head width is derived so that
# d_kv * num_heads == d_model.
D_MODEL = 4096
NUM_HEADS = 64

config = T5Config(
    vocab_size=26240,
    d_model=D_MODEL,
    d_ff=10240,
    d_kv=D_MODEL // NUM_HEADS,
    num_layers=2,
    num_heads=NUM_HEADS,
    relative_attention_num_buckets=32,
    dropout_rate=0.0,
    initializer_factor=1.0,
    eos_token_id=tokenizer.eod_id,
    bos_token_id=tokenizer.pad_id,
    pad_token_id=tokenizer.pad_id,
    decoder_start_token_id=tokenizer.pad_id,
    feed_forward_proj='gated-gelu',
    tie_word_embeddings=False
)

In [None]:
# Load onto CPU: the tensors are filtered and copied into a freshly built
# (CPU) model afterwards, so they need not land on the device they were saved from.
state_dict = torch.load('../cpm-2.1-encoder.pt', map_location='cpu')

In [26]:
model = EncoderModel(config=config)  # randomly initialised; pretrained weights are loaded next

In [27]:
def _keep_pretrained_key(key: str, num_layers: int) -> bool:
    """True for non-block weights and for ``encoder.block.<i>.`` weights with ``i < num_layers``."""
    match = re.search(r'encoder\.block\.(\d+)\.', key)
    return match is None or int(match.group(1)) < num_layers


# Keep the shared weights plus the first `config.num_layers` transformer blocks
# of the pretrained checkpoint; deeper blocks have no slot in the truncated model.
new_state_dict = {
    k: v
    for k, v in state_dict.items()
    if _keep_pretrained_key(k, config.num_layers)
}
# The checkpoint's final layer norm is replaced by the truncated model's own
# freshly initialised one (presumably because the pretrained norm was fit to the
# output of the full-depth stack — confirm).
new_state_dict['encoder.final_layer_norm.weight'] = model.encoder.state_dict()['encoder.final_layer_norm.weight']
model.encoder.load_state_dict(new_state_dict)

<All keys matched successfully>

In [28]:
train_model = TrainModel(encoder=model)  # shares `model`'s parameters with the training wrapper

In [29]:
%%time
ds = Dataset(batch_size=64)
# batch_size=None turns off the DataLoader's own batching: the dataset already
# yields fully collated batches.
dl = torch.utils.data.DataLoader(
    ds,
    batch_size=None,
    num_workers=4,
    prefetch_factor=10,
    pin_memory=True,
)

100%|██████████| 3/3 [00:00<00:00,  5.65it/s]
100%|██████████| 980227/980227 [00:01<00:00, 521399.29it/s]


980227
CPU times: user 3.28 s, sys: 271 ms, total: 3.55 s
Wall time: 3.55 s


In [30]:
%%time
fp16 = True
cuda = True

if cuda:
    # .half() and .cuda() convert in place and return the module itself, so the
    # wrapped `model` is converted too.
    train_model = (train_model.half() if fp16 else train_model).cuda()

CPU times: user 22.9 s, sys: 4.45 s, total: 27.3 s
Wall time: 435 ms


In [31]:
# Same optimiser settings for fp16 and fp32. Note there is no fp16-specific
# handling (e.g. loss scaling) for the half-precision path.
optimizer = torch.optim.SGD(
    train_model.parameters(),
    lr=5e-3,
    momentum=0.9
)

In [32]:
# Classification head: LogSoftmax over dim 1 produces log-probabilities for NLLLoss.
m = torch.nn.LogSoftmax(dim=1)
loss_fc_0 = torch.nn.NLLLoss()
# Similarity head: squared error between predicted similarity and the 0/1 label.
loss_fc_1 = torch.nn.MSELoss()

In [33]:
# Global step counter and rolling windows of recent total / per-head losses.
step = 0
losses, losses0, losses1 = [], [], []

In [35]:
# Train until interrupted: `dl` yields batches forever. `step` and the loss
# windows are defined in an earlier cell, so re-running this cell resumes counting.
SAVE_EVERY = 60 * 60  # checkpoint interval, in optimizer steps
WINDOW = 100          # number of recent steps averaged in the progress bar

pbar = tqdm(dl)

for x0, m0, x1, m1, y in pbar:

    optimizer.zero_grad()

    if cuda:
        # The DataLoader pins batches, so host->device copies can be asynchronous.
        x0, m0, x1, m1, y = (t.cuda(non_blocking=True) for t in (x0, m0, x1, m1, y))

    # CUDA autocast only applies when the model actually runs on the GPU.
    with torch.cuda.amp.autocast(enabled=cuda):
        c, sim = train_model(x0, m0, x1, m1)
        l0 = loss_fc_0(m(c), y)               # classification loss
        l1 = loss_fc_1(sim, y.to(sim.dtype))  # similarity regression loss
        loss = l0 + l1

    loss.backward()
    optimizer.step()

    losses.append(loss.item())
    losses0.append(l0.item())
    losses1.append(l1.item())
    # Trim the windows in place to the last WINDOW values.
    del losses[:-WINDOW]
    del losses0[:-WINDOW]
    del losses1[:-WINDOW]
    pbar.set_description(f'step: {step} loss: {np.mean(losses):.4f} l0: {np.mean(losses0):.4f} l1: {np.mean(losses1):.4f}')
    step += 1
    if step % SAVE_EVERY == 0:
        print('save', step, np.mean(losses))
        # Only the encoder (`model`) is checkpointed, not the classification head.
        torch.save(model.state_dict(), f'model_{step}.pt')
        torch.save(optimizer.state_dict(), f'opt_{step}.pt')

0it [00:00, ?it/s]Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Building prefix dict from the default dictionary ...Building prefix dict from the default dictionary ...Building prefix dict from the default dictionary ...


Loading model from cache /tmp/jieba.cacheLoading model from cache /tmp/jieba.cacheLoading model from cache /tmp/jieba.cache


Loading model cost 0.879 seconds.
Prefix dict has been built successfully.
Loading model cost 0.880 seconds.Loading model cost 0.880 seconds.

Prefix dict has been built successfully.Prefix dict has been built successfully.

Loading model cost 0.886 seconds.
Prefix dict has been built successfully.
step: 94337 loss: 0.4264 l0: 0.2368 l1: 0.1896: : 368it [01:26,  4.24it/s]


KeyboardInterrupt: 

In [None]:
# Export the sentence encoder on its own. Use explicit dummy inputs rather than
# whatever batch the training loop left behind (the old `e0` was never defined).
export_device = next(model.parameters()).device
# Token id 1 is the prefix id the training Dataset puts at position 0.
sample_ids = torch.ones(2, 16, dtype=torch.long, device=export_device)
sample_mask = torch.ones_like(sample_ids)  # int64, like the training masks

torch.onnx.export(
    model,
    (sample_ids, sample_mask),
    './model.onnx',
    opset_version=13,
    # NOTE(review): the first input carries token ids (EncoderModel's `input_ids`);
    # the name "encoder_outputs" is kept so existing consumers still match.
    input_names=[
        "encoder_outputs", "mask",
    ],
    output_names=["normalized_vector"],
    dynamic_axes={
        "encoder_outputs": {0: "batch", 1: "sequence"},
        "mask": {0: "batch", 1: "sequence"},
        # Output is (batch, embedding): only the batch axis varies.
        "normalized_vector": {0: "batch"},
    },
)

In [None]:
# !du -sh './model.onnx'