In [1]:
import re
from collections import OrderedDict

import torch
import numpy as np
import tensorflow as tf
from transformers import TFT5EncoderModel, TFT5Model
from transformers import T5Config

from tokenization_enc_dec import EncDecTokenizer

In [2]:
# Build the EVA tokenizer from its vocabulary file. Its eod_id / pad_id feed the
# T5Config below, and it encodes the test sentence near the end of the notebook.
tokenizer = EncDecTokenizer('./EVA/src/bpe_dialog_new/vocab.txt')

In [3]:
# List the contents of the checkpoint directory that the next cell loads from.
!ls ../eva-ckpt/222500

mp_rank_00_model_states.pt


In [4]:
# Load the checkpoint onto the CPU and keep only its 'module' entry, the
# parameter-name -> tensor mapping that the rest of the notebook reads.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints obtained from a trusted source.
state_dict = torch.load('../eva-ckpt/222500/mp_rank_00_model_states.pt', map_location='cpu')['module']

In [5]:
# Note: blocks.0 of the encoder and decoder differs from blocks 1, 2, 3, ..., 23:
# block 0 additionally carries relative_attention_bias.
# Print every parameter name in the checkpoint together with its shape.
for k, v in state_dict.items():
    print(k, v.shape)

word_embeds.weight torch.Size([30000, 2048])
lm_head.weight torch.Size([30000, 2048])
encoder.word_embeds.weight torch.Size([30000, 2048])
encoder.final_layernorm.weight torch.Size([2048])
encoder.blocks.0.self_attn.self_attn.project.weight torch.Size([6144, 2048])
encoder.blocks.0.self_attn.self_attn.relative_attention_bias.weight torch.Size([32, 32])
encoder.blocks.0.self_attn.self_attn.dense.weight torch.Size([2048, 2048])
encoder.blocks.0.self_attn.layer_norm.weight torch.Size([2048])
encoder.blocks.0.ff.dense_relu_dense.wi_0.weight torch.Size([5120, 2048])
encoder.blocks.0.ff.dense_relu_dense.wi_1.weight torch.Size([5120, 2048])
encoder.blocks.0.ff.dense_relu_dense.wo.weight torch.Size([2048, 5120])
encoder.blocks.0.ff.layer_norm.weight torch.Size([2048])
encoder.blocks.1.self_attn.self_attn.project.weight torch.Size([6144, 2048])
encoder.blocks.1.self_attn.self_attn.dense.weight torch.Size([2048, 2048])
encoder.blocks.1.self_attn.layer_norm.weight torch.Size([2048])
encoder.block

In [6]:
def get_weight(name):
    """Return ``state_dict[name].numpy()`` for a single checkpoint entry.

    Args:
        name: Key of the entry in the module-level ``state_dict``.

    Returns:
        Whatever the stored tensor's ``.numpy()`` call returns; the value is
        not transposed or otherwise modified.

    Raises:
        KeyError: If ``name`` is not a key of ``state_dict``.
    """
    tensor = state_dict[name]
    array = tensor.numpy()
    return array


def get_block_weight(n, t='encoder', name=False, dim=2048, source=None):
    """Collect the converted weights of one transformer block.

    Entries are taken from the checkpoint in iteration order and converted so
    that they line up with the per-block variables of the transformers TF T5
    model:

    * ``self_attn.project`` (fused Q/K/V) is cut along axis 0 into the row
      ranges ``[0, dim)``, ``[dim, 2*dim)`` and ``[2*dim, end)``; each part is
      transposed and emitted as its own entry.
    * ``cross_attn.project_q`` is transposed.
    * ``cross_attn.project_kv`` (fused K/V) is cut into ``[0, dim)`` and
      ``[dim, end)``; each part is transposed.
    * Any other key containing ``'dense'`` is transposed.
    * Everything else is passed through unchanged.

    If the fourth collected entry is a ``relative_attention_bias`` (which is
    where it lands when it directly follows the three self-attention parts), it
    is moved to the front of the list.

    Args:
        n: Block index; keys are matched on the substring ``f'blocks.{n}.'``.
        t: Substring that selects the stack, e.g. ``'encoder'`` or ``'decoder'``.
        name: If true, return ``(key, array)`` tuples instead of bare arrays.
            Split projections repeat the same key for every part.
        dim: Number of rows of each part of a fused projection.
        source: Mapping of parameter name to an object with a ``.numpy()``
            method. Defaults to the module-level ``state_dict``.

    Returns:
        A list of arrays, or of ``(key, array)`` tuples when ``name`` is true.

    Raises:
        IndexError: If no key matches ``t`` and block ``n``.
    """
    if source is None:
        source = state_dict

    block_marker = f'blocks.{n}.'
    weights = []
    for k, v in source.items():
        if t not in k or block_marker not in k:
            continue
        # The PyTorch and TensorFlow versions store weight matrices transposed
        # relative to each other, hence the transposes below.
        w = v.numpy()
        if 'self_attn.project' in k:
            for part in (w[:dim, :], w[dim:dim * 2, :], w[dim * 2:, :]):
                weights.append((k, np.transpose(part)))
        elif 'cross_attn.project_q' in k:
            weights.append((k, np.transpose(w)))
        elif 'cross_attn.project_kv' in k:
            for part in (w[:dim, :], w[dim:, :]):
                weights.append((k, np.transpose(part)))
        elif 'dense' in k:
            weights.append((k, np.transpose(w)))
        else:
            weights.append((k, w))

    if not weights:
        # Fail with a readable message instead of an opaque index error below.
        raise IndexError(f'no {t!r} weights found for block {n}')

    # Guard the positional lookup so blocks with fewer than four entries are
    # returned as-is instead of crashing.
    if len(weights) > 3 and 'relative_attention_bias' in weights[3][0]:
        weights = weights[3:4] + weights[:3] + weights[4:]
    if not name:
        weights = [x[1] for x in weights]
    return weights

In [7]:
# T5 configuration sized to match the checkpoint shapes printed above.
config = T5Config(
    # Architecture sizes.
    vocab_size=30000,
    d_model=2048,
    d_kv=2048 // 32,  # d_model divided by the number of heads
    d_ff=5120,
    num_heads=32,
    num_layers=24,
    relative_attention_num_buckets=32,
    feed_forward_proj='gated-gelu',
    tie_word_embeddings=False,
    # Initialisation and dropout.
    initializer_factor=1.0,
    dropout_rate=0.0,
    # Special token ids come from the tokenizer; the pad id is also used as
    # the BOS id and as the decoder start id.
    pad_token_id=tokenizer.pad_id,
    bos_token_id=tokenizer.pad_id,
    decoder_start_token_id=tokenizer.pad_id,
    eos_token_id=tokenizer.eod_id,
)

In [8]:
# Instantiate the encoder-only TF T5 model from the config alone; no pretrained
# weights are loaded here, the checkpoint weights are assigned further down.
model = TFT5EncoderModel(config)

In [9]:
# Run a forward pass on a single dummy token id.
# NOTE(review): presumably this makes sure all model variables exist before they
# are listed and overwritten below — confirm.
out = model(input_ids=tf.constant([[1]]))

In [10]:
# Show which fields the model output provides.
out.keys()

odict_keys(['last_hidden_state'])

In [11]:
# transformers的T5是把QKV分开的
for k in model.variables:
    print(k.name, k.shape)

shared/shared/weight:0 (30000, 2048)
tf_t5encoder_model/encoder/block_._0/layer_._0/SelfAttention/relative_attention_bias/embeddings:0 (32, 32)
tf_t5encoder_model/encoder/block_._0/layer_._0/SelfAttention/q/kernel:0 (2048, 2048)
tf_t5encoder_model/encoder/block_._0/layer_._0/SelfAttention/k/kernel:0 (2048, 2048)
tf_t5encoder_model/encoder/block_._0/layer_._0/SelfAttention/v/kernel:0 (2048, 2048)
tf_t5encoder_model/encoder/block_._0/layer_._0/SelfAttention/o/kernel:0 (2048, 2048)
tf_t5encoder_model/encoder/block_._0/layer_._0/layer_norm/weight:0 (2048,)
tf_t5encoder_model/encoder/block_._0/layer_._1/DenseReluDense/wi_0/kernel:0 (2048, 5120)
tf_t5encoder_model/encoder/block_._0/layer_._1/DenseReluDense/wi_1/kernel:0 (2048, 5120)
tf_t5encoder_model/encoder/block_._0/layer_._1/DenseReluDense/wo/kernel:0 (5120, 2048)
tf_t5encoder_model/encoder/block_._0/layer_._1/layer_norm/weight:0 (2048,)
tf_t5encoder_model/encoder/block_._1/layer_._0/SelfAttention/q/kernel:0 (2048, 2048)
tf_t5encoder_mod

In [12]:
# Preview the converted (name, shape) pairs of encoder block 0, for comparison
# with the TF variable listing above.
for x in get_block_weight(0, t='encoder', name=True):
    print(x[0], x[1].shape)

encoder.blocks.0.self_attn.self_attn.relative_attention_bias.weight (32, 32)
encoder.blocks.0.self_attn.self_attn.project.weight (2048, 2048)
encoder.blocks.0.self_attn.self_attn.project.weight (2048, 2048)
encoder.blocks.0.self_attn.self_attn.project.weight (2048, 2048)
encoder.blocks.0.self_attn.self_attn.dense.weight (2048, 2048)
encoder.blocks.0.self_attn.layer_norm.weight (2048,)
encoder.blocks.0.ff.dense_relu_dense.wi_0.weight (2048, 5120)
encoder.blocks.0.ff.dense_relu_dense.wi_1.weight (2048, 5120)
encoder.blocks.0.ff.dense_relu_dense.wo.weight (5120, 2048)
encoder.blocks.0.ff.layer_norm.weight (2048,)


In [13]:
# Same preview for encoder block 1, which has no relative_attention_bias entry.
for x in get_block_weight(1, t='encoder', name=True):
    print(x[0], x[1].shape)

encoder.blocks.1.self_attn.self_attn.project.weight (2048, 2048)
encoder.blocks.1.self_attn.self_attn.project.weight (2048, 2048)
encoder.blocks.1.self_attn.self_attn.project.weight (2048, 2048)
encoder.blocks.1.self_attn.self_attn.dense.weight (2048, 2048)
encoder.blocks.1.self_attn.layer_norm.weight (2048,)
encoder.blocks.1.ff.dense_relu_dense.wi_0.weight (2048, 5120)
encoder.blocks.1.ff.dense_relu_dense.wi_1.weight (2048, 5120)
encoder.blocks.1.ff.dense_relu_dense.wo.weight (5120, 2048)
encoder.blocks.1.ff.layer_norm.weight (2048,)


In [14]:
# Assemble the converted weights in the order the TF model lists its variables:
# the shared embedding first, then each encoder block, then the final layer norm.
model_new_weights = [get_weight('word_embeds.weight')]
# Read the block count from the config rather than repeating the literal 24, so
# the conversion loop and the model definition cannot drift apart.
for i in range(config.num_layers):
    model_new_weights += get_block_weight(i, t='encoder')
model_new_weights += [get_weight('encoder.final_layernorm.weight')]

In [15]:
# Number of variables in the TF model; expected to equal len(model_new_weights).
len(model.variables)

219

In [16]:
# Number of converted arrays; expected to equal len(model.variables).
len(model_new_weights)

219

In [17]:
# Assign the converted arrays to the model. The assignment is positional, so the
# list has to follow the model's own weight order (see the listings above).
model.set_weights(model_new_weights)

In [18]:
# Smoke-test the converted encoder: tokenize a short sentence into a batch of
# one and run it through the model.
input_ids = tf.constant([tokenizer.encode('你好')])
out = model(input_ids)

Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 1.767 seconds.
Prefix dict has been built successfully.


In [21]:
# Check which fields the forward pass with the converted weights returned.
out.keys()

odict_keys(['last_hidden_state'])

In [19]:
# Save the converted encoder to ./tf_eva_encoder via transformers' save_pretrained.
model.save_pretrained('./tf_eva_encoder')

In [20]:
# Report the on-disk size of the saved model directory.
!du -sh './tf_eva_encoder'

4.6G	./tf_eva_encoder
