In [15]:
from collections import OrderedDict

import torch
import tensorflow as tf
import numpy as np

from transformers import T5EncoderModel, T5ForConditionalGeneration, TFT5Model
from transformers import T5Config

In [2]:
from tokenization_enc_dec import EncDecTokenizer
# Vocabulary file consumed by the project's EncDec tokenizer.
VOCAB_PATH = './vocab.txt'
tokenizer = EncDecTokenizer(VOCAB_PATH)

In [3]:
# Architecture of the source checkpoint. The shapes agree with the converted state
# dict listed further down (26240 x 4096 embeddings, 10240-wide FFN, 32 x 64 bias table).
D_MODEL = 4096
N_HEADS = 64

config = T5Config(
    # sizes
    vocab_size=26240,
    d_model=D_MODEL,
    d_ff=10240,
    num_layers=24,
    num_heads=N_HEADS,
    d_kv=D_MODEL // N_HEADS,  # per-head width (64)
    relative_attention_num_buckets=32,
    feed_forward_proj='gated-gelu',
    # regularisation / init
    dropout_rate=0.0,
    initializer_factor=1.0,
    # special tokens come from the project tokenizer
    eos_token_id=tokenizer.eod_id,
    bos_token_id=tokenizer.pad_id,
    pad_token_id=tokenizer.pad_id,
    decoder_start_token_id=tokenizer.pad_id,
    # the source checkpoint stores input embeddings and lm_head separately
    tie_word_embeddings=False,
)

In [4]:
# Build the T5 model from the config; real weights are loaded further down.
# A freshly constructed nn.Module starts in training mode. This notebook only runs
# inference, so switch to eval() explicitly. Dropout is 0.0 today, but eval mode keeps
# outputs deterministic if that setting ever changes.
model = T5ForConditionalGeneration(config).eval()

In [6]:
# Smoke-test the forward pass with a single token id (1) on both encoder and decoder side.
# no_grad(): inference only, so don't build an autograd graph for this large model.
with torch.no_grad():
    out = model(input_ids=torch.LongTensor([[1]]), decoder_input_ids=torch.LongTensor([[1]]))

In [7]:
# Fields returned by the smoke-test forward pass (shown in the cell output).
out.keys()

odict_keys(['logits', 'past_key_values', 'encoder_last_hidden_state'])

In [9]:
# Count parameter tensors. parameters() yields a shared module's tensor only once, so this
# can be smaller than the number of state-dict entries.
sum(1 for _ in model.parameters())

558

In [10]:
def get_weight(name, sd=None):
    """Return tensor `name` from the source checkpoint as a NumPy array.

    Args:
        name: key in the source state dict (e.g. 'word_embeds.weight').
        sd: optional state dict to read from; defaults to the notebook-global
            `state_dict` loaded from the converted checkpoint.

    Returns:
        The tensor's data as an np.ndarray.

    Raises:
        KeyError: if `name` is not present in the state dict.
    """
    source = state_dict if sd is None else sd
    # detach()/cpu() make .numpy() safe for tensors that require grad or live on a GPU;
    # both are no-ops for plain CPU tensors.
    return source[name].detach().cpu().numpy()

# Target Hugging Face T5 parameter-name templates for a single block; `{}` is filled
# with the block index. Block 0 of each stack additionally owns the relative-attention
# bias table (compare the printed source and target state dicts), so the `*_names0`
# lists are used for block 0 and the plain lists for every other block.
encoder_names0 = [
    'encoder.block.{}.layer.0.SelfAttention.q.weight',
    'encoder.block.{}.layer.0.SelfAttention.k.weight',
    'encoder.block.{}.layer.0.SelfAttention.v.weight',
    'encoder.block.{}.layer.0.SelfAttention.o.weight',
    'encoder.block.{}.layer.0.SelfAttention.relative_attention_bias.weight',
    'encoder.block.{}.layer.0.layer_norm.weight',
    'encoder.block.{}.layer.1.DenseReluDense.wi_0.weight',
    'encoder.block.{}.layer.1.DenseReluDense.wi_1.weight',
    'encoder.block.{}.layer.1.DenseReluDense.wo.weight',
    'encoder.block.{}.layer.1.layer_norm.weight',
]

decoder_names0 = [
    'decoder.block.{}.layer.0.SelfAttention.q.weight',
    'decoder.block.{}.layer.0.SelfAttention.k.weight',
    'decoder.block.{}.layer.0.SelfAttention.v.weight',
    'decoder.block.{}.layer.0.SelfAttention.o.weight',
    'decoder.block.{}.layer.0.SelfAttention.relative_attention_bias.weight',
    'decoder.block.{}.layer.0.layer_norm.weight',
    'decoder.block.{}.layer.1.EncDecAttention.q.weight',
    'decoder.block.{}.layer.1.EncDecAttention.k.weight',
    'decoder.block.{}.layer.1.EncDecAttention.v.weight',
    'decoder.block.{}.layer.1.EncDecAttention.o.weight',
    'decoder.block.{}.layer.1.layer_norm.weight',
    'decoder.block.{}.layer.2.DenseReluDense.wi_0.weight',
    'decoder.block.{}.layer.2.DenseReluDense.wi_1.weight',
    'decoder.block.{}.layer.2.DenseReluDense.wo.weight',
    'decoder.block.{}.layer.2.layer_norm.weight',
]

# Blocks 1..N-1 have block 0's layout minus the relative-attention bias table.
# Deriving them avoids two hand-copied lists that can drift apart.
encoder_names = [name for name in encoder_names0 if 'relative_attention_bias' not in name]
decoder_names = [name for name in decoder_names0 if 'relative_attention_bias' not in name]


def get_block_weight(n, t='encoder', dim=4096, sd=None):
    """Collect one source block's weights, renamed to HF T5 parameter names.

    The source checkpoint fuses attention projections. `self_attn.project` stacks
    q/k/v along axis 0, and `cross_attn.project_kv` stacks k/v. Both are split here
    into `dim`-row slices. Tensors are then matched to the target name template
    positionally, in the source state dict's iteration order.

    Args:
        n: block index within the stack.
        t: stack name, 'encoder' or 'decoder'. Source keys are matched on the
            substring '{t}.blocks.{n}.'.
        dim: rows per projection in the fused matrices. For this config that is
            num_heads * d_kv = 4096.
        sd: optional state dict to read from; defaults to the notebook-global
            `state_dict`.

    Returns:
        OrderedDict mapping formatted HF T5 parameter names to NumPy arrays.

    Raises:
        ValueError: if the number of tensors found for the block does not match the
            target template (e.g. unknown block index or unexpected source layout).
    """
    source = state_dict if sd is None else sd
    prefix = f'{t}.blocks.{n}.'
    collected = []  # (source_key, array) pairs, in source order
    for key, tensor in source.items():
        if prefix not in key:
            continue
        # PyTorch and TensorFlow store these matrices transposed relative to each other.
        # Source and target are both PyTorch here, so no transpose is applied.
        w = tensor.detach().cpu().numpy()
        if 'self_attn.project' in key:
            # fused q/k/v -> three dim-row slices
            collected += [(key, w[:dim]), (key, w[dim:2 * dim]), (key, w[2 * dim:])]
        elif 'cross_attn.project_kv' in key:
            # fused k/v -> two dim-row slices
            collected += [(key, w[:dim]), (key, w[dim:])]
        else:
            # includes cross_attn.project_q, dense, layer norms and FFN weights
            collected.append((key, w))

    # The source orders block 0 as q, k, v, relative_attention_bias, dense (o), ...
    # The T5 template expects o before the bias table, so swap them.
    if len(collected) > 4 and 'relative_attention_bias' in collected[3][0]:
        collected[3], collected[4] = collected[4], collected[3]

    if t == 'encoder':
        templates = encoder_names0 if n == 0 else encoder_names
    else:
        templates = decoder_names0 if n == 0 else decoder_names

    # zip() would silently truncate on a mismatch and misassign names; fail loudly instead.
    if len(collected) != len(templates):
        raise ValueError(
            f'{t} block {n}: found {len(collected)} source tensors, '
            f'expected {len(templates)}'
        )
    return OrderedDict(
        (template.format(n), w) for template, (_, w) in zip(templates, collected)
    )

In [11]:
# Load the converted source checkpoint onto the CPU, so later `.numpy()` calls work
# regardless of which device the tensors were saved from.
# NOTE: torch.load unpickles the file and can execute arbitrary code; only load trusted checkpoints.
state_dict = torch.load('../converted.zip', map_location='cpu')

In [16]:
import warnings

# Assemble the full HF T5 state dict (NumPy arrays) from the source checkpoint.
model_new_weights = OrderedDict()
model_new_weights['shared.weight'] = get_weight('word_embeds.weight')
model_new_weights['encoder.embed_tokens.weight'] = get_weight('encoder.word_embeds.weight')

# Iterate over the configured depth rather than a hard-coded 24.
for i in range(config.num_layers):
    model_new_weights.update(get_block_weight(i, t='encoder'))

model_new_weights['encoder.final_layer_norm.weight'] = get_weight('encoder.final_layernorm.weight')

for i in range(config.num_decoder_layers):
    model_new_weights.update(get_block_weight(i, t='decoder'))

model_new_weights['decoder.final_layer_norm.weight'] = get_weight('decoder.final_layernorm.weight')

model_new_weights['decoder.embed_tokens.weight'] = get_weight('decoder.word_embeds.weight')
model_new_weights['lm_head.weight'] = get_weight('lm_head.weight')

# In the HF model, shared / encoder.embed_tokens / decoder.embed_tokens are one module
# (558 parameters vs 560 state-dict entries in this notebook's outputs). load_state_dict
# therefore writes all three into the same tensor and only the last one survives.
# Warn if the source copies are not identical, since that would be a silent loss.
_embedding_keys = ('encoder.embed_tokens.weight', 'decoder.embed_tokens.weight')
if not all(np.array_equal(model_new_weights['shared.weight'], model_new_weights[k]) for k in _embedding_keys):
    warnings.warn('source word embeddings differ between shared/encoder/decoder; '
                  'HF T5 ties them, so only one copy will be kept after loading')

In [17]:
# Number of converted entries; should equal len(model.state_dict()) (key sets compared below).
len(model_new_weights)

560

In [24]:
# Target keys the converted weights do not provide (expected: empty set).
model.state_dict().keys() - model_new_weights.keys()

set()

In [25]:
# Converted keys the target model does not have (expected: empty set).
model_new_weights.keys() - model.state_dict().keys()

set()

In [30]:
# List every tensor in the source checkpoint together with its shape.
for name in state_dict:
    print(name, state_dict[name].shape)

word_embeds.weight torch.Size([26240, 4096])
lm_head.weight torch.Size([26240, 4096])
encoder.word_embeds.weight torch.Size([26240, 4096])
encoder.final_layernorm.weight torch.Size([4096])
encoder.blocks.0.self_attn.self_attn.project.weight torch.Size([12288, 4096])
encoder.blocks.0.self_attn.self_attn.relative_attention_bias.weight torch.Size([32, 64])
encoder.blocks.0.self_attn.self_attn.dense.weight torch.Size([4096, 4096])
encoder.blocks.0.self_attn.layer_norm.weight torch.Size([4096])
encoder.blocks.0.ff.dense_relu_dense.wi_0.weight torch.Size([10240, 4096])
encoder.blocks.0.ff.dense_relu_dense.wi_1.weight torch.Size([10240, 4096])
encoder.blocks.0.ff.dense_relu_dense.wo.weight torch.Size([4096, 10240])
encoder.blocks.0.ff.layer_norm.weight torch.Size([4096])
encoder.blocks.1.self_attn.self_attn.project.weight torch.Size([12288, 4096])
encoder.blocks.1.self_attn.self_attn.dense.weight torch.Size([4096, 4096])
encoder.blocks.1.self_attn.layer_norm.weight torch.Size([4096])
encoder.

In [31]:
# Target layout: every entry of the HF model's state dict with its shape.
target_state = model.state_dict()
for param_name, param in target_state.items():
    print(param_name, param.shape)

shared.weight torch.Size([26240, 4096])
encoder.embed_tokens.weight torch.Size([26240, 4096])
encoder.block.0.layer.0.SelfAttention.q.weight torch.Size([4096, 4096])
encoder.block.0.layer.0.SelfAttention.k.weight torch.Size([4096, 4096])
encoder.block.0.layer.0.SelfAttention.v.weight torch.Size([4096, 4096])
encoder.block.0.layer.0.SelfAttention.o.weight torch.Size([4096, 4096])
encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight torch.Size([32, 64])
encoder.block.0.layer.0.layer_norm.weight torch.Size([4096])
encoder.block.0.layer.1.DenseReluDense.wi_0.weight torch.Size([10240, 4096])
encoder.block.0.layer.1.DenseReluDense.wi_1.weight torch.Size([10240, 4096])
encoder.block.0.layer.1.DenseReluDense.wo.weight torch.Size([4096, 10240])
encoder.block.0.layer.1.layer_norm.weight torch.Size([4096])
encoder.block.1.layer.0.SelfAttention.q.weight torch.Size([4096, 4096])
encoder.block.1.layer.0.SelfAttention.k.weight torch.Size([4096, 4096])
encoder.block.1.layer.0.SelfAtten

In [28]:
# Converted layout: each key of model_new_weights with its NumPy shape, for comparison
# against the target listing.
for key in model_new_weights:
    array = model_new_weights[key]
    print(key, array.shape)

shared.weight (26240, 4096)
encoder.embed_tokens.weight (26240, 4096)
encoder.block.0.layer.0.SelfAttention.q.weight (4096, 4096)
encoder.block.0.layer.0.SelfAttention.k.weight (4096, 4096)
encoder.block.0.layer.0.SelfAttention.v.weight (4096, 4096)
encoder.block.0.layer.0.SelfAttention.o.weight (4096, 4096)
encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight (32, 64)
encoder.block.0.layer.0.layer_norm.weight (4096,)
encoder.block.0.layer.1.DenseReluDense.wi_0.weight (10240, 4096)
encoder.block.0.layer.1.DenseReluDense.wi_1.weight (10240, 4096)
encoder.block.0.layer.1.DenseReluDense.wo.weight (4096, 10240)
encoder.block.0.layer.1.layer_norm.weight (4096,)
encoder.block.1.layer.0.SelfAttention.q.weight (4096, 4096)
encoder.block.1.layer.0.SelfAttention.k.weight (4096, 4096)
encoder.block.1.layer.0.SelfAttention.v.weight (4096, 4096)
encoder.block.1.layer.0.SelfAttention.o.weight (4096, 4096)
encoder.block.1.layer.0.layer_norm.weight (4096,)
encoder.block.1.layer.1.Dense

In [16]:
# Sanity check: one converted array per entry of the PyTorch model's state dict.
# This is the PyTorch replacement for the stale TensorFlow `model.variables` check.
assert len(model_new_weights) == len(model.state_dict())

In [32]:
# Wrap the NumPy arrays as tensors (torch.from_numpy shares memory) and load them.
# With the default strict=True, missing/unexpected keys raise, and size mismatches raise too.
torch_state_dict = {key: torch.from_numpy(array) for key, array in model_new_weights.items()}
model.load_state_dict(torch_state_dict)

<All keys matched successfully>

In [35]:
input_text = '''当地时间9月6日是美国劳工节，但就在这一天，上千万美国劳动者却陷入新的困境。因为美国政府为疫情期间失业者提供的主要救助同日到期，而且白宫表示没有进一步延长救助的计划。
在德尔塔变异株已把美国推入新一轮疫情的背景下，失业救济的突然“断供”意味着有上千万美国人将全部或部分失去他们的生活来源。'''
# Encoder input: the article ids followed by the first sentinel id. Presumably this marks
# the span the model should fill in (T5-style span corruption); TODO confirm.
article_ids = tokenizer.encode(input_text) + [tokenizer.get_sentinel_id(0)]
input_ids = torch.LongTensor([article_ids])

In [36]:
# Sample a continuation. Sampling is stochastic, so seed torch's RNG to make the
# sampled text reproducible across re-runs.
torch.manual_seed(42)
# Ids from 26050 up to the tokenizer's vocab size are banned. This is the same cut-off
# as the logits slice in the greedy-decoding cell; presumably these are special/sentinel
# tokens (TODO confirm against the vocab).
first_banned_id = 26050
out = model.generate(
    input_ids,
    max_length=50,
    decoder_start_token_id=1,
    do_sample=True,
    top_p=0.95,
    top_k=20,
    bad_words_ids=[[token_id] for token_id in range(first_banned_id, tokenizer.vocab_size)],
)

In [39]:
# Take the first generated sequence and drop its leading decoder-start token before detokenising.
generated_ids = out[0, 1:].tolist()
out_text = tokenizer.decode(generated_ids)

In [40]:
# Show the sampled continuation.
print(out_text)

加油的希望被打破了。
为什么会这样？ 
 “美国政府将在今明两天“断供”？
 
美国政府为疫情期间失业救被断  
 
美国政府


In [42]:
input_text = '''当地时间9月6日是美国劳工节，但就在这一天，上千万美国劳动者却陷入新的困境。因为美国政府为疫情期间失业者提供的主要救助同日到期，而且白宫表示没有进一步延长救助的计划。
在德尔塔变异株已把美国推入新一轮疫情的背景下，失业救济的突然“断供”意味着有上千万美国人将全部或部分失去他们的生活来源。'''
output_text = '''美国'''
# Unlike the sampling cell, the encoder input has no trailing sentinel; the sentinel
# goes on the decoder side, right after the start token.
input_ids = torch.LongTensor([tokenizer.encode(input_text)])

# Decoder start id 1 (as in the generate call) followed by the first sentinel id.
decoder_prefix = [1, tokenizer.get_sentinel_id(0)]
# Keep generated ids directly instead of re-encoding the decoded text every step,
# because a decode -> encode round trip is not guaranteed to reproduce the same ids.
output_ids = tokenizer.encode(output_text)

# Manual greedy decoding for 20 steps, printing the text after each step.
with torch.no_grad():  # inference only: skip autograd bookkeeping
    # The encoder input never changes, so run the encoder once and reuse its output.
    encoder_outputs = model.get_encoder()(input_ids=input_ids)
    for _ in range(20):
        decoder_input_ids = torch.LongTensor([decoder_prefix + output_ids])
        out = model(encoder_outputs=encoder_outputs, decoder_input_ids=decoder_input_ids)
        # Only ids below 26050 are eligible, matching the bad_words_ids cut-off used for sampling.
        next_id = int(out['logits'][0, -1, :26050].argmax())
        output_ids.append(next_id)
        output_text += tokenizer.decode([next_id])
        print(output_text)

美国劳动
美国劳动者
美国劳动者的
美国劳动者的新
美国劳动者的新困
美国劳动者的新困境
美国劳动者的新困境。
美国劳动者的新困境。

美国劳动者的新困境。
美国
美国劳动者的新困境。
美国政府
美国劳动者的新困境。
美国政府为
美国劳动者的新困境。
美国政府为疫情
美国劳动者的新困境。
美国政府为疫情期
美国劳动者的新困境。
美国政府为疫情期间
美国劳动者的新困境。
美国政府为疫情期间失业
美国劳动者的新困境。
美国政府为疫情期间失业者
美国劳动者的新困境。
美国政府为疫情期间失业者提供
美国劳动者的新困境。
美国政府为疫情期间失业者提供的
美国劳动者的新困境。
美国政府为疫情期间失业者提供的主要
美国劳动者的新困境。
美国政府为疫情期间失业者提供的主要救助
