In [1]:
import re
from collections import OrderedDict

import torch
import numpy as np
from transformers import T5EncoderModel, T5ForConditionalGeneration, T5Model
from transformers import T5Config

from tokenization_enc_dec import EncDecTokenizer

In [2]:
tokenizer = EncDecTokenizer('./EVA/src/bpe_dialog_new/vocab.txt')

In [3]:
!ls ../eva-ckpt/222500

mp_rank_00_model_states.pt


In [4]:
state_dict = torch.load('../eva-ckpt/222500/mp_rank_00_model_states.pt', map_location='cpu')['module']

In [5]:
# 注意encoder和decoder的blocks.0和其他的1,2,3...23不一样，0多了relative_attention_bias
for k, v in state_dict.items():
    print(k, v.shape)

word_embeds.weight torch.Size([30000, 2048])
lm_head.weight torch.Size([30000, 2048])
encoder.word_embeds.weight torch.Size([30000, 2048])
encoder.final_layernorm.weight torch.Size([2048])
encoder.blocks.0.self_attn.self_attn.project.weight torch.Size([6144, 2048])
encoder.blocks.0.self_attn.self_attn.relative_attention_bias.weight torch.Size([32, 32])
encoder.blocks.0.self_attn.self_attn.dense.weight torch.Size([2048, 2048])
encoder.blocks.0.self_attn.layer_norm.weight torch.Size([2048])
encoder.blocks.0.ff.dense_relu_dense.wi_0.weight torch.Size([5120, 2048])
encoder.blocks.0.ff.dense_relu_dense.wi_1.weight torch.Size([5120, 2048])
encoder.blocks.0.ff.dense_relu_dense.wo.weight torch.Size([2048, 5120])
encoder.blocks.0.ff.layer_norm.weight torch.Size([2048])
encoder.blocks.1.self_attn.self_attn.project.weight torch.Size([6144, 2048])
encoder.blocks.1.self_attn.self_attn.dense.weight torch.Size([2048, 2048])
encoder.blocks.1.self_attn.layer_norm.weight torch.Size([2048])
encoder.block

In [158]:
def get_weight(name):
    return state_dict[name].numpy()

encoder_names0 = [
    'encoder.block.{}.layer.0.SelfAttention.q.weight',
    'encoder.block.{}.layer.0.SelfAttention.k.weight',
    'encoder.block.{}.layer.0.SelfAttention.v.weight',
    'encoder.block.{}.layer.0.SelfAttention.o.weight',
    'encoder.block.{}.layer.0.SelfAttention.relative_attention_bias.weight',
    'encoder.block.{}.layer.0.layer_norm.weight',
    'encoder.block.{}.layer.1.DenseReluDense.wi_0.weight',
    'encoder.block.{}.layer.1.DenseReluDense.wi_1.weight',
    'encoder.block.{}.layer.1.DenseReluDense.wo.weight',
    'encoder.block.{}.layer.1.layer_norm.weight',
]

decoder_names0 = [
    'decoder.block.{}.layer.0.SelfAttention.q.weight',
    'decoder.block.{}.layer.0.SelfAttention.k.weight',
    'decoder.block.{}.layer.0.SelfAttention.v.weight',
    'decoder.block.{}.layer.0.SelfAttention.o.weight',
    'decoder.block.{}.layer.0.SelfAttention.relative_attention_bias.weight',
    'decoder.block.{}.layer.0.layer_norm.weight',
    'decoder.block.{}.layer.1.EncDecAttention.q.weight',
    'decoder.block.{}.layer.1.EncDecAttention.k.weight',
    'decoder.block.{}.layer.1.EncDecAttention.v.weight',
    'decoder.block.{}.layer.1.EncDecAttention.o.weight',
    'decoder.block.{}.layer.1.layer_norm.weight',
    'decoder.block.{}.layer.2.DenseReluDense.wi_0.weight',
    'decoder.block.{}.layer.2.DenseReluDense.wi_1.weight',
    'decoder.block.{}.layer.2.DenseReluDense.wo.weight',
    'decoder.block.{}.layer.2.layer_norm.weight',
]

encoder_names = [
    'encoder.block.{}.layer.0.SelfAttention.q.weight',
    'encoder.block.{}.layer.0.SelfAttention.k.weight',
    'encoder.block.{}.layer.0.SelfAttention.v.weight',
    'encoder.block.{}.layer.0.SelfAttention.o.weight',
    'encoder.block.{}.layer.0.layer_norm.weight',
    'encoder.block.{}.layer.1.DenseReluDense.wi_0.weight',
    'encoder.block.{}.layer.1.DenseReluDense.wi_1.weight',
    'encoder.block.{}.layer.1.DenseReluDense.wo.weight',
    'encoder.block.{}.layer.1.layer_norm.weight',
]

decoder_names = [
    'decoder.block.{}.layer.0.SelfAttention.q.weight',
    'decoder.block.{}.layer.0.SelfAttention.k.weight',
    'decoder.block.{}.layer.0.SelfAttention.v.weight',
    'decoder.block.{}.layer.0.SelfAttention.o.weight',
    'decoder.block.{}.layer.0.layer_norm.weight',
    'decoder.block.{}.layer.1.EncDecAttention.q.weight',
    'decoder.block.{}.layer.1.EncDecAttention.k.weight',
    'decoder.block.{}.layer.1.EncDecAttention.v.weight',
    'decoder.block.{}.layer.1.EncDecAttention.o.weight',
    'decoder.block.{}.layer.1.layer_norm.weight',
    'decoder.block.{}.layer.2.DenseReluDense.wi_0.weight',
    'decoder.block.{}.layer.2.DenseReluDense.wi_1.weight',
    'decoder.block.{}.layer.2.DenseReluDense.wo.weight',
    'decoder.block.{}.layer.2.layer_norm.weight',
]

def get_block_weight(n, t='encoder', dim=2048):
    weights = []
    for k, v in state_dict.items():
        if t in k and f'blocks.{n}.' in k:
            # pytorch和tensorflow版本的weights是矩阵转置的
            w = v.numpy()
            if 'self_attn.project' in k:
                w0, w1, w2 = w[:dim, :], w[dim:dim*2, :], w[dim*2:, :]
#                 w0 = np.transpose(w0)
#                 w1 = np.transpose(w1)
#                 w2 = np.transpose(w2)
                weights.append((k, w0))
                weights.append((k, w1))
                weights.append((k, w2))
            elif 'cross_attn.project_q' in k:
#                 w = np.transpose(w)
                weights.append((k, w))
            elif 'cross_attn.project_kv' in k:
                w0, w1 = w[:dim, :], w[dim:, :]
#                 w0 = np.transpose(w0)
#                 w1 = np.transpose(w1)
                weights.append((k, w0))
                weights.append((k, w1))
            else:
#                 if 'dense' in k:
#                     w = np.transpose(w)
                weights.append((k, w))
    if 'relative_attention_bias' in weights[3][0]:
        weights[3], weights[4] = weights[4], weights[3]
    weights = [x[1] for x in weights]
    if 'encoder' == t:
        weights_dict = OrderedDict()
        for k, v in zip(encoder_names0 if n == 0 else encoder_names, weights):
            weights_dict[k.format(n)] = v
        weights = weights_dict
    else:
        weights_dict = OrderedDict()
        for k, v in zip(decoder_names0 if n == 0 else decoder_names, weights):
            weights_dict[k.format(n)] = v
        weights = weights_dict
    return weights

In [7]:
config = T5Config(
    vocab_size=30000,
    # n_positions=self.n_positions,
    d_model=2048,
    d_ff=5120,
    d_kv=2048 // 32,
    num_layers=24,
    num_heads=32,
    relative_attention_num_buckets=32,
    dropout_rate=0.0,
    initializer_factor=1.0,
    eos_token_id=tokenizer.eod_id,
    bos_token_id=tokenizer.pad_id,
    pad_token_id=tokenizer.pad_id,
    decoder_start_token_id=tokenizer.pad_id,
    feed_forward_proj='gated-gelu',
    tie_word_embeddings=False
)

In [8]:
model = T5ForConditionalGeneration(config)

In [9]:
_ = model(input_ids=torch.LongTensor([[1]]), decoder_input_ids=torch.LongTensor([[1]]))

<generator object Module.parameters at 0x7eff84acb580>

In [122]:
# transformers的T5是把QKV分开的
for k, v in model.state_dict().items():
    print(k, v.shape)

shared.weight torch.Size([30000, 2048])
encoder.embed_tokens.weight torch.Size([30000, 2048])
encoder.block.0.layer.0.SelfAttention.q.weight torch.Size([2048, 2048])
encoder.block.0.layer.0.SelfAttention.k.weight torch.Size([2048, 2048])
encoder.block.0.layer.0.SelfAttention.v.weight torch.Size([2048, 2048])
encoder.block.0.layer.0.SelfAttention.o.weight torch.Size([2048, 2048])
encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight torch.Size([32, 32])
encoder.block.0.layer.0.layer_norm.weight torch.Size([2048])
encoder.block.0.layer.1.DenseReluDense.wi_0.weight torch.Size([5120, 2048])
encoder.block.0.layer.1.DenseReluDense.wi_1.weight torch.Size([5120, 2048])
encoder.block.0.layer.1.DenseReluDense.wo.weight torch.Size([2048, 5120])
encoder.block.0.layer.1.layer_norm.weight torch.Size([2048])
encoder.block.1.layer.0.SelfAttention.q.weight torch.Size([2048, 2048])
encoder.block.1.layer.0.SelfAttention.k.weight torch.Size([2048, 2048])
encoder.block.1.layer.0.SelfAttentio

In [123]:
for x in get_block_weight(0, t='encoder').items():
    print(x[0], x[1].shape)

encoder.block.0.layer.0.SelfAttention.q.weight (2048, 2048)
encoder.block.0.layer.0.SelfAttention.k.weight (2048, 2048)
encoder.block.0.layer.0.SelfAttention.v.weight (2048, 2048)
encoder.block.0.layer.0.SelfAttention.o.weight (2048, 2048)
encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight (32, 32)
encoder.block.0.layer.0.layer_norm.weight (2048,)
encoder.block.0.layer.1.DenseReluDense.wi_0.weight (5120, 2048)
encoder.block.0.layer.1.DenseReluDense.wi_1.weight (5120, 2048)
encoder.block.0.layer.1.DenseReluDense.wo.weight (2048, 5120)
encoder.block.0.layer.1.layer_norm.weight (2048,)


In [124]:
for x in get_block_weight(1, t='encoder').items():
    print(x[0], x[1].shape)

encoder.block.1.layer.0.SelfAttention.q.weight (2048, 2048)
encoder.block.1.layer.0.SelfAttention.k.weight (2048, 2048)
encoder.block.1.layer.0.SelfAttention.v.weight (2048, 2048)
encoder.block.1.layer.0.SelfAttention.o.weight (2048, 2048)
encoder.block.1.layer.0.layer_norm.weight (2048,)
encoder.block.1.layer.1.DenseReluDense.wi_0.weight (5120, 2048)
encoder.block.1.layer.1.DenseReluDense.wi_1.weight (5120, 2048)
encoder.block.1.layer.1.DenseReluDense.wo.weight (2048, 5120)
encoder.block.1.layer.1.layer_norm.weight (2048,)


In [125]:
for x in get_block_weight(0, t='decoder').items():
    print(x[0], x[1].shape)

decoder.block.0.layer.0.SelfAttention.q.weight (2048, 2048)
decoder.block.0.layer.0.SelfAttention.k.weight (2048, 2048)
decoder.block.0.layer.0.SelfAttention.v.weight (2048, 2048)
decoder.block.0.layer.0.SelfAttention.o.weight (2048, 2048)
decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight (32, 32)
decoder.block.0.layer.0.layer_norm.weight (2048,)
decoder.block.0.layer.1.EncDecAttention.q.weight (2048, 2048)
decoder.block.0.layer.1.EncDecAttention.k.weight (2048, 2048)
decoder.block.0.layer.1.EncDecAttention.v.weight (2048, 2048)
decoder.block.0.layer.1.EncDecAttention.o.weight (2048, 2048)
decoder.block.0.layer.1.layer_norm.weight (2048,)
decoder.block.0.layer.2.DenseReluDense.wi_0.weight (5120, 2048)
decoder.block.0.layer.2.DenseReluDense.wi_1.weight (5120, 2048)
decoder.block.0.layer.2.DenseReluDense.wo.weight (2048, 5120)
decoder.block.0.layer.2.layer_norm.weight (2048,)


In [126]:
for x in get_block_weight(1, t='decoder').items():
    print(x[0], x[1].shape)

decoder.block.1.layer.0.SelfAttention.q.weight (2048, 2048)
decoder.block.1.layer.0.SelfAttention.k.weight (2048, 2048)
decoder.block.1.layer.0.SelfAttention.v.weight (2048, 2048)
decoder.block.1.layer.0.SelfAttention.o.weight (2048, 2048)
decoder.block.1.layer.0.layer_norm.weight (2048,)
decoder.block.1.layer.1.EncDecAttention.q.weight (2048, 2048)
decoder.block.1.layer.1.EncDecAttention.k.weight (2048, 2048)
decoder.block.1.layer.1.EncDecAttention.v.weight (2048, 2048)
decoder.block.1.layer.1.EncDecAttention.o.weight (2048, 2048)
decoder.block.1.layer.1.layer_norm.weight (2048,)
decoder.block.1.layer.2.DenseReluDense.wi_0.weight (5120, 2048)
decoder.block.1.layer.2.DenseReluDense.wi_1.weight (5120, 2048)
decoder.block.1.layer.2.DenseReluDense.wo.weight (2048, 5120)
decoder.block.1.layer.2.layer_norm.weight (2048,)


In [159]:
# shared.weight torch.Size([30000, 2048])
# encoder.embed_tokens.weight torch.Size([30000, 2048])
model_new_weights = OrderedDict()
model_new_weights['shared.weight'] = get_weight('word_embeds.weight')
model_new_weights['encoder.embed_tokens.weight'] = get_weight('encoder.word_embeds.weight')
for i in range(24):
    for k, v in get_block_weight(i, t='encoder').items():
        model_new_weights[k] = v

model_new_weights['encoder.final_layer_norm.weight'] = get_weight('encoder.final_layernorm.weight')

for i in range(24):
    for k, v in get_block_weight(i, t='decoder').items():
        model_new_weights[k] = v

model_new_weights['decoder.final_layer_norm.weight'] = get_weight('decoder.final_layernorm.weight')

model_new_weights['decoder.embed_tokens.weight'] = get_weight('decoder.word_embeds.weight')
model_new_weights['lm_head.weight'] = get_weight('lm_head.weight')

In [160]:
len(model_new_weights)

560

In [161]:
len(model.state_dict())

560

In [162]:
model.load_state_dict({k: torch.from_numpy(v) for k, v in model_new_weights.items()})

<All keys matched successfully>

In [163]:
decoder_input = [tokenizer.get_sentinel_id(0)]
input_ids = torch.LongTensor([tokenizer.encode('你好啊') + [tokenizer.sep_id, tokenizer.get_sentinel_id(0)]])

for i in range(10):
    decoder_input_ids = torch.LongTensor([decoder_input])
    outputs = model(input_ids, decoder_input_ids=decoder_input_ids)
    next_token = np.argsort(outputs.logits[0, -1, :].detach().numpy())[-1]
    decoder_input.append(next_token)
    print(tokenizer.decode(decoder_input[1:]))

你好
你好,
你好,我
你好,我是
你好,我是你
你好,我是你的
你好,我是你的粉丝
你好,我是你的粉丝<sep>
你好,我是你的粉丝<sep><sep>
你好,我是你的粉丝<sep><sep>,


In [164]:
model.save_pretrained('./torch_eva')

In [165]:
!du -sh torch_eva

11G	torch_eva


In [170]:
sentence = []

while True:
    text = input('>>>')
    if text == 'quit' or len(text.strip()) <= 0:
        break
    sentence.append(text)

    input_ids = []
    for x in sentence:
        input_ids += tokenizer.encode(x) + [tokenizer.sep_id]
    input_ids += [tokenizer.get_sentinel_id(0)]
    print(input_ids)
    input_ids = torch.LongTensor([input_ids])
    
    out = model.generate(
        input_ids,
        decoder_start_token_id=tokenizer.get_sentinel_id(0),
        eos_token_id=tokenizer.sep_id,
        do_sample=True,
        top_p=0.95,
        top_k=50
    )
    out_text = tokenizer.decode(out.numpy()[0].tolist()[1:-1])
    print(out_text)
    sentence.append(out_text)

>>>你好
[5503, 4, 29810]
请问您是咨询之前的问题还是有其他的问题需要处理呢?
>>>


In [171]:
input_text = '你好'
decode_text = '你也好啊'

In [174]:
input_ids = tokenizer.encode(input_text) + [tokenizer.sep_id, tokenizer.get_sentinel_id(0)]

In [177]:
decoder_input_ids = [tokenizer.get_sentinel_id(0)] + tokenizer.encode(decode_text)
labels = tokenizer.encode(decode_text) + [tokenizer.sep_id]

In [181]:
out = model(
    torch.LongTensor(input_ids).unsqueeze(0),
    decoder_input_ids=torch.LongTensor(decoder_input_ids).unsqueeze(0),
    labels=torch.LongTensor(labels).unsqueeze(0))

In [182]:
out.loss

tensor(4.4332, grad_fn=<NllLossBackward>)