In [1]:
from collections import OrderedDict

import torch
import tensorflow as tf
import numpy as np

from transformers import T5EncoderModel, TFT5Model
from transformers import T5Config

In [2]:
# Project-local tokenizer module; the tokenizer is built from the vocab file in
# the current directory and supplies the special-token ids used by the config.
from tokenization_enc_dec import EncDecTokenizer
tokenizer = EncDecTokenizer('./vocab.txt')

In [3]:
# Architecture of the encoder that the converted checkpoint is loaded into.
# The pad id doubles as BOS and as the decoder start token.
config = T5Config(
    # --- vocabulary / special tokens ---
    vocab_size=26240,
    pad_token_id=tokenizer.pad_id,
    bos_token_id=tokenizer.pad_id,
    eos_token_id=tokenizer.eod_id,
    decoder_start_token_id=tokenizer.pad_id,
    tie_word_embeddings=False,
    # --- width / depth ---
    d_model=4096,
    d_ff=10240,
    num_layers=24,
    num_heads=64,
    d_kv=4096 // 64,  # d_model // num_heads
    relative_attention_num_buckets=32,
    # --- feed-forward, regularisation, init ---
    feed_forward_proj='gated-gelu',
    dropout_rate=0.0,
    initializer_factor=1.0,
)

In [4]:
# Build the encoder-only model from `config`; no checkpoint is read here — the
# converted weights are copied in further down via `load_state_dict`.
model = T5EncoderModel(config)

In [5]:
# Smoke-test the freshly built model on a single dummy token id.
# This is inference only, so run it under no_grad to avoid recording an
# autograd graph for the activations of every layer.
with torch.no_grad():
    out = model(input_ids=torch.LongTensor([[1]]))

In [6]:
# Show which fields the model output carries.
out.keys()

odict_keys(['last_hidden_state'])

In [7]:
# Number of parameter tensors yielded by the model.
sum(1 for _ in model.parameters())

219

In [8]:
def get_weight(name):
    """Return the checkpoint tensor stored under ``name`` as a NumPy array.

    ``name`` is looked up in the module-level ``state_dict``, which must have
    been loaded before this function is called.

    The tensor is detached and moved to the CPU before conversion because
    ``Tensor.numpy()`` rejects tensors that require grad or live on another
    device. For a plain CPU tensor both steps are no-ops, so the returned
    array is the same as before.

    Raises:
        KeyError: if ``name`` is not a key of ``state_dict``.
    """
    return state_dict[name].detach().cpu().numpy()

# Target parameter-name templates for encoder block 0, in the order in which
# get_block_weight pairs them with that block's arrays. `{}` receives the block
# index. This variant includes the relative_attention_bias entry.
encoder_names0 = [
    'encoder.block.{}.' + suffix
    for suffix in (
        'layer.0.SelfAttention.q.weight',
        'layer.0.SelfAttention.k.weight',
        'layer.0.SelfAttention.v.weight',
        'layer.0.SelfAttention.o.weight',
        'layer.0.SelfAttention.relative_attention_bias.weight',
        'layer.0.layer_norm.weight',
        'layer.1.DenseReluDense.wi_0.weight',
        'layer.1.DenseReluDense.wi_1.weight',
        'layer.1.DenseReluDense.wo.weight',
        'layer.1.layer_norm.weight',
    )
]

# Target parameter-name templates for decoder block 0 (self-attention,
# encoder-decoder attention, feed-forward), in the order in which
# get_block_weight pairs them with that block's arrays. `{}` receives the block
# index. This variant includes the relative_attention_bias entry.
decoder_names0 = [
    'decoder.block.{}.' + suffix
    for suffix in (
        'layer.0.SelfAttention.q.weight',
        'layer.0.SelfAttention.k.weight',
        'layer.0.SelfAttention.v.weight',
        'layer.0.SelfAttention.o.weight',
        'layer.0.SelfAttention.relative_attention_bias.weight',
        'layer.0.layer_norm.weight',
        'layer.1.EncDecAttention.q.weight',
        'layer.1.EncDecAttention.k.weight',
        'layer.1.EncDecAttention.v.weight',
        'layer.1.EncDecAttention.o.weight',
        'layer.1.layer_norm.weight',
        'layer.2.DenseReluDense.wi_0.weight',
        'layer.2.DenseReluDense.wi_1.weight',
        'layer.2.DenseReluDense.wo.weight',
        'layer.2.layer_norm.weight',
    )
]

# Target parameter-name templates for encoder blocks other than block 0: the
# same sequence as for block 0 but without a relative_attention_bias entry.
# `{}` receives the block index.
encoder_names = [
    'encoder.block.{}.' + suffix
    for suffix in (
        'layer.0.SelfAttention.q.weight',
        'layer.0.SelfAttention.k.weight',
        'layer.0.SelfAttention.v.weight',
        'layer.0.SelfAttention.o.weight',
        'layer.0.layer_norm.weight',
        'layer.1.DenseReluDense.wi_0.weight',
        'layer.1.DenseReluDense.wi_1.weight',
        'layer.1.DenseReluDense.wo.weight',
        'layer.1.layer_norm.weight',
    )
]

# Target parameter-name templates for decoder blocks other than block 0: the
# same sequence as for block 0 but without a relative_attention_bias entry.
# `{}` receives the block index.
decoder_names = [
    'decoder.block.{}.' + suffix
    for suffix in (
        'layer.0.SelfAttention.q.weight',
        'layer.0.SelfAttention.k.weight',
        'layer.0.SelfAttention.v.weight',
        'layer.0.SelfAttention.o.weight',
        'layer.0.layer_norm.weight',
        'layer.1.EncDecAttention.q.weight',
        'layer.1.EncDecAttention.k.weight',
        'layer.1.EncDecAttention.v.weight',
        'layer.1.EncDecAttention.o.weight',
        'layer.1.layer_norm.weight',
        'layer.2.DenseReluDense.wi_0.weight',
        'layer.2.DenseReluDense.wi_1.weight',
        'layer.2.DenseReluDense.wo.weight',
        'layer.2.layer_norm.weight',
    )
]

def get_block_weight(n, t='encoder', dim=4096):
    """Collect the weights of block ``n`` and key them by the target names.

    Scans the module-level ``state_dict`` in its stored order and keeps every
    entry whose key contains both ``t`` and ``'blocks.{n}.'``. Fused attention
    projections are split along axis 0:

    * ``self_attn.project``     -> rows ``[:dim]``, ``[dim:2*dim]``, ``[2*dim:]``
    * ``cross_attn.project_kv`` -> rows ``[:dim]``, ``[dim:]``

    Every other matching tensor is passed through unchanged. No transposition
    is applied to any weight.

    The collected arrays are then paired positionally with the name templates:
    ``encoder_names0``/``encoder_names`` when ``t == 'encoder'``, otherwise
    ``decoder_names0``/``decoder_names`` (the ``*0`` list is used for block 0).

    Args:
        n: Block index.
        t: Substring filter applied to checkpoint keys. ``'encoder'`` selects
            the encoder templates; any other value selects the decoder ones.
        dim: Row count of each leading slice taken from a fused projection.

    Returns:
        OrderedDict mapping formatted target names to NumPy arrays.

    Raises:
        ValueError: if the number of collected arrays differs from the number
            of name templates (this includes the case where no checkpoint key
            matches). Previously a mismatch was silently truncated by ``zip``
            or surfaced as an opaque ``IndexError``.
    """
    block_tag = f'blocks.{n}.'
    collected = []  # (source key, array) pairs in checkpoint order
    for key, tensor in state_dict.items():
        if t not in key or block_tag not in key:
            continue
        # detach()/cpu() are no-ops for plain CPU tensors but keep .numpy()
        # from failing on tensors that require grad or live on another device.
        w = tensor.detach().cpu().numpy()
        if 'self_attn.project' in key:
            # Fused q/k/v projection.
            parts = (w[:dim, :], w[dim:dim * 2, :], w[dim * 2:, :])
        elif 'cross_attn.project_q' in key:
            parts = (w,)
        elif 'cross_attn.project_kv' in key:
            # Fused k/v projection.
            parts = (w[:dim, :], w[dim:, :])
        else:
            parts = (w,)
        collected.extend((key, part) for part in parts)

    # When relative_attention_bias sits right after the q/k/v slices (position
    # 3), move it behind the following entry so the order matches the
    # templates, which list the attention output projection first.
    if len(collected) > 4 and 'relative_attention_bias' in collected[3][0]:
        collected[3], collected[4] = collected[4], collected[3]

    if t == 'encoder':
        templates = encoder_names0 if n == 0 else encoder_names
    else:
        templates = decoder_names0 if n == 0 else decoder_names

    if len(collected) != len(templates):
        raise ValueError(
            f'block {n} ({t}): collected {len(collected)} weights from the '
            f'checkpoint but expected {len(templates)}'
        )

    return OrderedDict(
        (template.format(n), array)
        for template, (_, array) in zip(templates, collected)
    )

In [9]:
# Load the converted checkpoint.
# map_location='cpu' puts every tensor on the CPU regardless of the device it
# was saved from; get_weight / get_block_weight call `.numpy()` on these
# tensors, which only works for CPU tensors.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
state_dict = torch.load('../converted.zip', map_location='cpu')

In [10]:
# Assemble the target-named state dict from the converted checkpoint:
# embeddings first, then every encoder block, then the final layer norm.
model_new_weights = OrderedDict()
model_new_weights['shared.weight'] = get_weight('word_embeds.weight')
model_new_weights['encoder.embed_tokens.weight'] = get_weight('encoder.word_embeds.weight')
# Take the block count from the config the model was built from, instead of
# repeating the literal 24, so the two cannot drift apart.
for i in range(config.num_layers):
    model_new_weights.update(get_block_weight(i, t='encoder'))

model_new_weights['encoder.final_layer_norm.weight'] = get_weight('encoder.final_layernorm.weight')

In [11]:
# Number of entries in the assembled state dict.
len(model_new_weights)

220

In [12]:
# Keys the model expects that the assembled dict does not provide.
set(model.state_dict()) - set(model_new_weights)

set()

In [13]:
# Keys in the assembled dict that the model does not know about.
set(model_new_weights) - set(model.state_dict())

set()

In [14]:
# Source checkpoint: one line per entry, name followed by tensor shape.
for src_name, src_tensor in state_dict.items():
    print(src_name, src_tensor.shape)

word_embeds.weight torch.Size([26240, 4096])
lm_head.weight torch.Size([26240, 4096])
encoder.word_embeds.weight torch.Size([26240, 4096])
encoder.final_layernorm.weight torch.Size([4096])
encoder.blocks.0.self_attn.self_attn.project.weight torch.Size([12288, 4096])
encoder.blocks.0.self_attn.self_attn.relative_attention_bias.weight torch.Size([32, 64])
encoder.blocks.0.self_attn.self_attn.dense.weight torch.Size([4096, 4096])
encoder.blocks.0.self_attn.layer_norm.weight torch.Size([4096])
encoder.blocks.0.ff.dense_relu_dense.wi_0.weight torch.Size([10240, 4096])
encoder.blocks.0.ff.dense_relu_dense.wi_1.weight torch.Size([10240, 4096])
encoder.blocks.0.ff.dense_relu_dense.wo.weight torch.Size([4096, 10240])
encoder.blocks.0.ff.layer_norm.weight torch.Size([4096])
encoder.blocks.1.self_attn.self_attn.project.weight torch.Size([12288, 4096])
encoder.blocks.1.self_attn.self_attn.dense.weight torch.Size([4096, 4096])
encoder.blocks.1.self_attn.layer_norm.weight torch.Size([4096])
encoder.

In [15]:
# Target model: one line per state-dict entry, name followed by tensor shape.
for param_name, param_tensor in model.state_dict().items():
    print(param_name, param_tensor.shape)

shared.weight torch.Size([26240, 4096])
encoder.embed_tokens.weight torch.Size([26240, 4096])
encoder.block.0.layer.0.SelfAttention.q.weight torch.Size([4096, 4096])
encoder.block.0.layer.0.SelfAttention.k.weight torch.Size([4096, 4096])
encoder.block.0.layer.0.SelfAttention.v.weight torch.Size([4096, 4096])
encoder.block.0.layer.0.SelfAttention.o.weight torch.Size([4096, 4096])
encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight torch.Size([32, 64])
encoder.block.0.layer.0.layer_norm.weight torch.Size([4096])
encoder.block.0.layer.1.DenseReluDense.wi_0.weight torch.Size([10240, 4096])
encoder.block.0.layer.1.DenseReluDense.wi_1.weight torch.Size([10240, 4096])
encoder.block.0.layer.1.DenseReluDense.wo.weight torch.Size([4096, 10240])
encoder.block.0.layer.1.layer_norm.weight torch.Size([4096])
encoder.block.1.layer.0.SelfAttention.q.weight torch.Size([4096, 4096])
encoder.block.1.layer.0.SelfAttention.k.weight torch.Size([4096, 4096])
encoder.block.1.layer.0.SelfAtten

In [16]:
# Assembled dict: one line per entry, name followed by array shape, for
# side-by-side comparison with the model's own listing.
for new_name, new_array in model_new_weights.items():
    print(new_name, new_array.shape)

shared.weight (26240, 4096)
encoder.embed_tokens.weight (26240, 4096)
encoder.block.0.layer.0.SelfAttention.q.weight (4096, 4096)
encoder.block.0.layer.0.SelfAttention.k.weight (4096, 4096)
encoder.block.0.layer.0.SelfAttention.v.weight (4096, 4096)
encoder.block.0.layer.0.SelfAttention.o.weight (4096, 4096)
encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight (32, 64)
encoder.block.0.layer.0.layer_norm.weight (4096,)
encoder.block.0.layer.1.DenseReluDense.wi_0.weight (10240, 4096)
encoder.block.0.layer.1.DenseReluDense.wi_1.weight (10240, 4096)
encoder.block.0.layer.1.DenseReluDense.wo.weight (4096, 10240)
encoder.block.0.layer.1.layer_norm.weight (4096,)
encoder.block.1.layer.0.SelfAttention.q.weight (4096, 4096)
encoder.block.1.layer.0.SelfAttention.k.weight (4096, 4096)
encoder.block.1.layer.0.SelfAttention.v.weight (4096, 4096)
encoder.block.1.layer.0.SelfAttention.o.weight (4096, 4096)
encoder.block.1.layer.0.layer_norm.weight (4096,)
encoder.block.1.layer.1.Dense

In [17]:
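# Disabled check: a torch nn.Module has no `.variables` attribute, so this line cannot run
# against `model`; the key-set comparisons above cover the same ground.
# assert len(model_new_weights) == len(model.variables)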

In [18]:
# Copy the converted arrays into the model, wrapping each one as a tensor.
model.load_state_dict({
    name: torch.from_numpy(array)
    for name, array in model_new_weights.items()
})

<All keys matched successfully>

In [19]:
# Example passage (Chinese news text) used as the model input below.
input_text = '''当地时间9月6日是美国劳工节，但就在这一天，上千万美国劳动者却陷入新的困境。因为美国政府为疫情期间失业者提供的主要救助同日到期，而且白宫表示没有进一步延长救助的计划。
在德尔塔变异株已把美国推入新一轮疫情的背景下，失业救济的突然“断供”意味着有上千万美国人将全部或部分失去他们的生活来源。'''
# Encode the text and wrap the ids in an outer list so the tensor has a leading axis of size 1.
input_ids = torch.LongTensor([tokenizer.encode(input_text)])

Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.548 seconds.
Prefix dict has been built successfully.


In [20]:
# Encode the example text with the converted weights.
# This is inference only, so run it under no_grad to avoid recording an
# autograd graph for the activations of every layer.
with torch.no_grad():
    out = model(input_ids)

In [23]:
# Show which fields the model output carries.
out.keys()

odict_keys(['last_hidden_state'])

In [37]:
# Shape of the encoder output for the example input.
out.last_hidden_state.shape

torch.Size([1, 102, 4096])

In [38]:
# !rm -rf onnx

In [39]:
# Create the output directory for the ONNX export; -p makes this a no-op if it already exists.
!mkdir -p onnx

In [40]:
# Export the encoder to ONNX, using the example `input_ids` as the model input.
# export_params and opset_version are not passed, so the exporter's defaults apply.
# NOTE(review): no dynamic_axes are passed and the printed graph types its input
# as Long(1:102, 102:1) — confirm whether the exported model accepts other
# sequence lengths before relying on it.
torch.onnx.export(
    model,                           # model being run
    input_ids,                       # model input (or a tuple for multiple inputs)
    "onnx/cpm_2_0_encoder.onnx",     # where to save the model (can be a file or file-like object)
    verbose=True,                    # print the exported graph
    use_external_data_format=True,   # file size > 2G
)

graph(%0 : Long(1:102, 102:1),
      %shared.weight : Float(26240:4096, 4096:1),
      %encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight : Float(32:64, 64:1),
      %encoder.block.0.layer.0.layer_norm.weight : Float(4096:1),
      %encoder.block.0.layer.1.layer_norm.weight : Float(4096:1),
      %encoder.block.1.layer.0.layer_norm.weight : Float(4096:1),
      %encoder.block.1.layer.1.layer_norm.weight : Float(4096:1),
      %encoder.block.2.layer.0.layer_norm.weight : Float(4096:1),
      %encoder.block.2.layer.1.layer_norm.weight : Float(4096:1),
      %encoder.block.3.layer.0.layer_norm.weight : Float(4096:1),
      %encoder.block.3.layer.1.layer_norm.weight : Float(4096:1),
      %encoder.block.4.layer.0.layer_norm.weight : Float(4096:1),
      %encoder.block.4.layer.1.layer_norm.weight : Float(4096:1),
      %encoder.block.5.layer.0.layer_norm.weight : Float(4096:1),
      %encoder.block.5.layer.1.layer_norm.weight : Float(4096:1),
      %encoder.block.6.layer.

In [41]:
# List the working directory (the `onnx` output directory should now be present).
!ls

bert			       test_tf_encoder.ipynb
configuration_enc_dec.py       test_tokenizer.ipynb
convert_multi_to_single.ipynb  tokenization_enc_dec.py
model.py		       to_pytorch_encoder_only.ipynb
onnx			       to_pytorch.ipynb
__pycache__		       to_tensorflow.ipynb
README.md		       vocab.txt
test_model_generation.ipynb


In [53]:
# Report the total on-disk size of the exported ONNX directory.
!du -sh "onnx"

18G	onnx
