In [203]:
import re
from collections import OrderedDict

import torch
import numpy as np
from tqdm import tqdm
from transformers import T5ForConditionalGeneration
from transformers import T5Config

from tokenization_enc_dec import EncDecTokenizer

In [201]:
# Build the tokenizer from the EVA BPE vocabulary file. It supplies the special
# token ids (pad, eod, sep, sentinel) used by the config and the prompts below.
tokenizer = EncDecTokenizer('./EVA/src/bpe_dialog_new/vocab.txt')

In [127]:
# Configuration of the Hugging Face T5 model that receives the converted
# checkpoint weights further down in this notebook.
config = T5Config(
    vocab_size=30000,
    # n_positions=self.n_positions,
    d_model=2048,
    d_ff=5120,
    d_kv=2048 // 32,  # d_model // num_heads = 64 per attention head
    num_layers=24,
    num_heads=32,
    relative_attention_num_buckets=32,
    dropout_rate=0.0,
    initializer_factor=1.0,
    # bos, pad and the default decoder start all share the tokenizer's pad id;
    # the generate() calls below pass their own decoder_start_token_id and
    # eos_token_id instead of these defaults.
    eos_token_id=tokenizer.eod_id,
    bos_token_id=tokenizer.pad_id,
    pad_token_id=tokenizer.pad_id,
    decoder_start_token_id=tokenizer.pad_id,
    feed_forward_proj='gated-gelu',  # gated feed-forward: the model has wi_0 and wi_1 weights
    tie_word_embeddings=False  # lm_head.weight is filled from its own checkpoint tensor
)

In [128]:
# Instantiate the model from the config only; no pretrained weights are loaded
# here. The converted checkpoint is loaded into it with load_state_dict later.
model = T5ForConditionalGeneration(config)

In [129]:
# Smoke test: a single forward pass with one-token encoder and decoder inputs.
# The result is discarded.
_ = model(input_ids=torch.LongTensor([[1]]), decoder_input_ids=torch.LongTensor([[1]]))

In [130]:
# The transformers T5 implementation keeps the Q, K and V projections as separate weights.
# Print every parameter name and shape to see the layout the conversion must produce.
for k, v in model.state_dict().items():
    print(k, v.shape)

shared.weight torch.Size([30000, 2048])
encoder.embed_tokens.weight torch.Size([30000, 2048])
encoder.block.0.layer.0.SelfAttention.q.weight torch.Size([2048, 2048])
encoder.block.0.layer.0.SelfAttention.k.weight torch.Size([2048, 2048])
encoder.block.0.layer.0.SelfAttention.v.weight torch.Size([2048, 2048])
encoder.block.0.layer.0.SelfAttention.o.weight torch.Size([2048, 2048])
encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight torch.Size([32, 32])
encoder.block.0.layer.0.layer_norm.weight torch.Size([2048])
encoder.block.0.layer.1.DenseReluDense.wi_0.weight torch.Size([5120, 2048])
encoder.block.0.layer.1.DenseReluDense.wi_1.weight torch.Size([5120, 2048])
encoder.block.0.layer.1.DenseReluDense.wo.weight torch.Size([2048, 5120])
encoder.block.0.layer.1.layer_norm.weight torch.Size([2048])
encoder.block.1.layer.0.SelfAttention.q.weight torch.Size([2048, 2048])
encoder.block.1.layer.0.SelfAttention.k.weight torch.Size([2048, 2048])
encoder.block.1.layer.0.SelfAttentio

In [8]:
import struct
import numpy as np
from collections import OrderedDict


def load_dtype(fp):
    """Read a one-byte dtype tag from ``fp`` and return the matching numpy type.

    Tag 0 maps to ``np.int8``, 1 to ``np.float16`` and 2 to ``np.float32``.

    Raises:
        TypeError: if the tag is not one of the three known values.
    """
    (tag,) = struct.unpack("B", fp.read(1))
    known_types = {0: np.int8, 1: np.float16, 2: np.float32}
    if tag not in known_types:
        raise TypeError("Unknown dtype %d" % tag)
    return known_types[tag]

def load_string(fp):
    """Read a length-prefixed UTF-8 string from ``fp``.

    The stream holds an unsigned 32-bit byte count (native byte order, as
    given by the ``"I"`` struct format) followed by that many bytes of text.
    """
    (byte_count,) = struct.unpack("I", fp.read(4))
    return fp.read(byte_count).decode("utf-8")

def load_tuple(fp):
    """Read a tuple of unsigned 32-bit integers from ``fp``.

    The stream holds a one-byte element count followed by that many
    ``"I"``-format integers, which are read one at a time in order.
    """
    (length,) = struct.unpack("B", fp.read(1))
    return tuple(struct.unpack("I", fp.read(4))[0] for _ in range(length))

def load_parameter(fp):
    """Read one serialized parameter record from ``fp``.

    The record is read in this order: shape tuple, payload size in bytes
    (unsigned 32-bit), dtype tag, then the raw payload itself.

    Returns:
        ``(shape, value, dtype)`` where ``value`` is the undecoded payload
        as returned by ``fp.read``.
    """
    shape = load_tuple(fp)
    (payload_bytes,) = struct.unpack("I", fp.read(4))
    dtype = load_dtype(fp)
    return shape, fp.read(payload_bytes), dtype

def load(fp, parent_name=''):
    """Recursively read a layer and its sub-layers from ``fp``.

    A layer starts with two unsigned 32-bit counts: the number of its own
    parameters and the number of sub-layers. Its parameters come first,
    followed by each named sub-layer.

    Returns:
        A flat list of ``(name, array)`` pairs. Each name is the dotted path
        ``parent_name + '.' + name``, so names from a top-level call start
        with a dot. Each array is the raw payload viewed through
        ``np.frombuffer`` and reshaped to the stored shape.
    """
    own_count, child_count = struct.unpack("II", fp.read(8))
    collected = []

    for _ in range(own_count):
        leaf_name = load_string(fp)
        shape, raw, dtype = load_parameter(fp)
        array = np.frombuffer(raw, dtype).reshape(shape)
        collected.append((f'{parent_name}.{leaf_name}', array))

    for _ in range(child_count):
        child_name = load_string(fp)
        collected.extend(load(fp, f'{parent_name}.{child_name}'))

    return collected

In [9]:
# Parse the binary checkpoint into a flat list of (dotted_name, ndarray) pairs.
# The file is opened in binary mode because load() decodes it with struct.
with open('../zhiyuan/eva/checkpoint.pt', 'rb') as fp:
    parameters = load(fp)

In [10]:
# Name -> array lookup over the loaded (name, array) pairs, used to find
# the '_scale' companion of a parameter by name.
pindex = dict(parameters)

In [16]:
# Build the final name -> array mapping. Entries whose name contains '_scale'
# are left out; any parameter that has a '<name>_scale' companion is cast to
# float16 and multiplied by that companion.
npara = {}
for name, value in tqdm(parameters):
    if '_scale' in name:
        continue
    scale = pindex.get(name + '_scale')
    if scale is not None:
        value = value.astype(np.float16) * scale
    npara[name] = value

100%|██████████| 704/704 [00:44<00:00, 15.65it/s]


In [90]:
def get_encoder(n):
    """Collect the Hugging Face state-dict entries for encoder block ``n``.

    Scans the module-level ``npara`` mapping in iteration order and renames
    every entry whose key contains ``'.encoder.{n}.'``. The fused
    ``w_project_qkv`` value is split by its first index into the q, k and v
    weights. For ``n == 0`` the encoder position-bias embedding is also
    emitted, as the relative attention bias of block 0.

    Returns:
        A list of ``(target_name, value)`` pairs.
    """
    qkv_targets = (
        f'encoder.block.{n}.layer.0.SelfAttention.q.weight',
        f'encoder.block.{n}.layer.0.SelfAttention.k.weight',
        f'encoder.block.{n}.layer.0.SelfAttention.v.weight',
    )
    # (substring of the checkpoint key, target name), in the order they are tested.
    # 'nrom' is the spelling these lookups rely on; it must stay as is.
    renames = (
        ('self_attention.w_out',
         f'encoder.block.{n}.layer.0.SelfAttention.o.weight'),
        ('layer_nrom_before_self_attn.weight',
         f'encoder.block.{n}.layer.0.layer_norm.weight'),
        ('dense_gelu_dense.wi_0.weight',
         f'encoder.block.{n}.layer.1.DenseReluDense.wi_0.weight'),
        ('dense_gelu_dense.wi_1.weight',
         f'encoder.block.{n}.layer.1.DenseReluDense.wi_1.weight'),
        ('dense_gelu_dense.wo.weight',
         f'encoder.block.{n}.layer.1.DenseReluDense.wo.weight'),
        ('layer_nrom_before_ff.weight',
         f'encoder.block.{n}.layer.1.layer_norm.weight'),
    )
    block_marker = f'.encoder.{n}.'

    params = []
    for key, tensor in npara.items():
        if n == 0 and '.encoder_position_bias.embedding.weight' in key:
            params.append((
                'encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight',
                tensor,
            ))
        if block_marker not in key:
            continue
        if 'self_attention.w_project_qkv' in key:
            params.extend(
                (target, tensor[idx]) for idx, target in enumerate(qkv_targets)
            )
        params.extend(
            (target, tensor) for pattern, target in renames if pattern in key
        )
    return params

In [185]:
def get_decoder(n):
    """Collect the Hugging Face state-dict entries for decoder block ``n``.

    Scans the module-level ``npara`` mapping in iteration order and renames
    every entry whose key contains ``'.decoder.{n}.'``: self-attention goes
    to layer 0, cross-attention to layer 1 and the feed-forward to layer 2.
    The fused self-attention ``w_project_qkv`` value is split by its first
    index into q, k and v. For ``n == 0`` the decoder position-bias embedding
    is also emitted, as the relative attention bias of block 0.

    The cross-attention k and v weights are appended last. They are taken
    from ``npara['.encoder_kv.w_project_kv']`` at index ``[n][0]`` and
    ``[n][1]``, not from the encoder's own self-attention projection.

    Returns:
        A list of ``(target_name, value)`` pairs.
    """
    qkv_targets = (
        f'decoder.block.{n}.layer.0.SelfAttention.q.weight',
        f'decoder.block.{n}.layer.0.SelfAttention.k.weight',
        f'decoder.block.{n}.layer.0.SelfAttention.v.weight',
    )
    # (substring of the checkpoint key, target name), in the order they are tested.
    # 'nrom' is the spelling these lookups rely on; it must stay as is.
    renames = (
        ('self_attention.w_out',
         f'decoder.block.{n}.layer.0.SelfAttention.o.weight'),
        ('layer_nrom_before_self_attn.weight',
         f'decoder.block.{n}.layer.0.layer_norm.weight'),
        ('.cross_attention.w_project_q',
         f'decoder.block.{n}.layer.1.EncDecAttention.q.weight'),
        ('.cross_attention.w_out',
         f'decoder.block.{n}.layer.1.EncDecAttention.o.weight'),
        ('layer_nrom_before_cross_attn.weight',
         f'decoder.block.{n}.layer.1.layer_norm.weight'),
        ('dense_gelu_dense.wi_0.weight',
         f'decoder.block.{n}.layer.2.DenseReluDense.wi_0.weight'),
        ('dense_gelu_dense.wi_1.weight',
         f'decoder.block.{n}.layer.2.DenseReluDense.wi_1.weight'),
        ('dense_gelu_dense.wo.weight',
         f'decoder.block.{n}.layer.2.DenseReluDense.wo.weight'),
        ('layer_nrom_before_ff.weight',
         f'decoder.block.{n}.layer.2.layer_norm.weight'),
    )
    block_marker = f'.decoder.{n}.'

    params = []
    for key, tensor in npara.items():
        if n == 0 and '.decoder_position_bias.embedding.weight' in key:
            params.append((
                'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight',
                tensor,
            ))
        if block_marker not in key:
            continue
        if 'self_attention.w_project_qkv' in key:
            params.extend(
                (target, tensor[idx]) for idx, target in enumerate(qkv_targets)
            )
        params.extend(
            (target, tensor) for pattern, target in renames if pattern in key
        )

    cross_kv = npara['.encoder_kv.w_project_kv'][n]
    params.append((
        f'decoder.block.{n}.layer.1.EncDecAttention.k.weight',
        cross_kv[0],
    ))
    params.append((
        f'decoder.block.{n}.layer.1.EncDecAttention.v.weight',
        cross_kv[1],
    ))
    return params

In [186]:
# Assemble the full list of (target_name, array) pairs in Hugging Face T5 naming.

# The checkpoint has one input embedding table; it fills the shared embedding
# and both the encoder and decoder token embeddings.
input_embedding = npara['.input_embedding.weight']

new_state_dict = [
    ('shared.weight', input_embedding),
    ('encoder.embed_tokens.weight', input_embedding),
]

# Take the layer counts from the config rather than repeating a literal, so
# the conversion stays in step with the model that was actually built.
for i in range(config.num_layers):
    new_state_dict += get_encoder(i)

# 'nrom' is the spelling these checkpoint lookups rely on; it must stay as is.
new_state_dict.append((
    'encoder.final_layer_norm.weight',
    npara['.encoder_final_layer_nrom.weight'],
))

new_state_dict.append((
    'decoder.embed_tokens.weight',
    input_embedding,
))

for i in range(config.num_decoder_layers):
    new_state_dict += get_decoder(i)

new_state_dict.append((
    'decoder.final_layer_norm.weight',
    npara['.decoder_final_layer_nrom.weight'],
))

# The output projection has its own tensor (the config sets tie_word_embeddings=False).
new_state_dict.append((
    'lm_head.weight',
    npara['.lm_head.weight'],
))

In [187]:
# Number of converted tensors; compare with the size of the model's state dict.
len(new_state_dict)

560

In [188]:
# Number of tensors in the model's state dict, for comparison with the converted list.
len(model.state_dict())

560

In [189]:
# Wrap each converted numpy array as a tensor and load the whole set into the
# model. The displayed result reports whether the names matched.
model.load_state_dict(dict(
    (name, torch.from_numpy(array)) for name, array in new_state_dict
))

<All keys matched successfully>

In [190]:
# Put the model in evaluation mode before running the inference cells below.
model = model.eval()

In [199]:
# Manual greedy decoding, ten steps, printing the partial reply after each step.
# The prompt is the encoded utterance followed by the sep token and sentinel 0;
# the decoder starts from sentinel 0.
decoder_input = [tokenizer.get_sentinel_id(0)]
input_ids = torch.LongTensor([
    tokenizer.encode('你喜欢郭德纲吗？') + [tokenizer.sep_id, tokenizer.get_sentinel_id(0)]])

# Inference only: no_grad stops an autograd graph being recorded for every
# forward pass.
with torch.no_grad():
    for _ in range(10):
        decoder_input_ids = torch.LongTensor([decoder_input])
        outputs = model(input_ids, decoder_input_ids=decoder_input_ids)
        # Greedy choice: the highest-scoring token at the last decoder position.
        # argmax finds it directly instead of sorting the whole vocabulary.
        next_token = int(outputs.logits[0, -1, :].argmax())
        decoder_input.append(next_token)
        # Skip the leading sentinel when rendering the reply so far.
        print(tokenizer.decode(decoder_input[1:]))

喜
喜欢
喜欢啊
喜欢啊,
喜欢啊,但
喜欢啊,但是
喜欢啊,但是我
喜欢啊,但是我不
喜欢啊,但是我不喜
喜欢啊,但是我不喜欢


In [196]:
# Sample one reply with generate(). The prompt is the encoded utterance followed
# by the sep token and sentinel 0; decoding starts from sentinel 0 and treats
# the sep token as end-of-sequence.
input_ids = torch.LongTensor([
    tokenizer.encode('你喜欢郭德纲吗') + [tokenizer.sep_id, tokenizer.get_sentinel_id(0)]]
)
out = model.generate(
    input_ids,
    decoder_start_token_id=tokenizer.get_sentinel_id(0),
    eos_token_id=tokenizer.sep_id,
    do_sample=True,
    top_p=0.85,
    top_k=10
)
# Drop the decoder start token. Strip the final token only when it really is
# the sep token: generation can also stop at the length limit, in which case
# the last token is part of the reply and must be kept.
reply_ids = out[0].tolist()[1:]
if reply_ids and reply_ids[-1] == tokenizer.sep_id:
    reply_ids = reply_ids[:-1]
out_text = tokenizer.decode(reply_ids)
print(out_text)

喜欢啊


In [229]:
# Interactive multi-turn chat. Every user turn and every model reply is appended
# to `sentence`, and the whole history is re-encoded as the prompt for the next
# reply. Typing 'quit' or an empty/blank line ends the loop.
sentence = []

while True:
    text = input('>>>')
    if text == 'quit' or not text.strip():
        break
    sentence.append(text)

    # Prompt layout: each turn followed by the sep token, then sentinel 0.
    prompt_ids = []
    for turn in sentence:
        prompt_ids += tokenizer.encode(turn) + [tokenizer.sep_id]
    prompt_ids.append(tokenizer.get_sentinel_id(0))
    input_ids = torch.LongTensor([prompt_ids])

    out = model.generate(
        input_ids,
        decoder_start_token_id=tokenizer.get_sentinel_id(0),
        eos_token_id=tokenizer.sep_id,
        do_sample=True,
        top_p=0.85,
        top_k=10
    )
    # Drop the decoder start token. Strip the final token only when it really
    # is the sep token: generation can also stop at the length limit, in which
    # case the last token is part of the reply and must be kept.
    reply_ids = out[0].tolist()[1:]
    if reply_ids and reply_ids[-1] == tokenizer.sep_id:
        reply_ids = reply_ids[:-1]
    out_text = tokenizer.decode(reply_ids)
    print(out_text)
    sentence.append(out_text)

>>>你好
您好,请问有什么问题小妹可以帮您处理或解决呢?
>>>你能做啥啊
亲亲,您方便简单描述下您的问题吗?
>>>你喜欢郭德纲吗
嗯嗯亲亲您方便简单描述下您的问题吗?
>>>你喜欢郭德纲吗
喜欢呀
>>>你听过他的什么？
亲亲是在哪啊
>>>


In [None]:
# model.save_pretrained('./torch_eva')

In [115]:
# input_text = '你好'
# decode_text = '你也好啊'
# input_ids = tokenizer.encode(input_text) + [tokenizer.sep_id, tokenizer.get_sentinel_id(0)]
# decoder_input_ids = [tokenizer.get_sentinel_id(0)] + tokenizer.encode(decode_text)
# labels = tokenizer.encode(decode_text) + [tokenizer.sep_id]
# out = model(
#     torch.LongTensor(input_ids).unsqueeze(0),
#     decoder_input_ids=torch.LongTensor(decoder_input_ids).unsqueeze(0),
#     labels=torch.LongTensor(labels).unsqueeze(0))
# out.loss