In [1]:
import re
from collections import OrderedDict

import torch
import numpy as np
from tqdm import tqdm
from transformers import T5ForConditionalGeneration
from transformers import T5Config

from tokenization_enc_dec import EncDecTokenizer

In [3]:
tokenizer = EncDecTokenizer('../zhiyuan/cpm2.1/vocab.txt')

In [6]:
config = T5Config(
    vocab_size=26240,
#     n_positions=self.n_positions,
    d_model=4096,
    d_ff=10240,
    d_kv=4096 // 64,
    num_layers=24,
    num_heads=64,
    relative_attention_num_buckets=32,
    dropout_rate=0.0,
    initializer_factor=1.0,
    eos_token_id=tokenizer.eod_id,
    bos_token_id=tokenizer.pad_id,
    pad_token_id=tokenizer.pad_id,
    decoder_start_token_id=tokenizer.pad_id,
    feed_forward_proj='gated-gelu',
    tie_word_embeddings=False
)

In [7]:
model = T5ForConditionalGeneration(config)

In [8]:
out = model(input_ids=torch.LongTensor([[1]]), decoder_input_ids=torch.LongTensor([[1]]))

In [9]:
out.keys()

odict_keys(['logits', 'past_key_values', 'encoder_last_hidden_state'])

In [10]:
import struct
import numpy as np
from collections import OrderedDict


def load_dtype(fp):
    v = struct.unpack("B", fp.read(1))[0]
    if v == 0:
        return np.int8
    elif v == 1:
        return np.float16
    elif v == 2:
        return np.float32
    else:
        raise TypeError("Unknown dtype %d" % v)

def load_string(fp):
    size = struct.unpack("I", fp.read(4))[0]
    v = fp.read(size)
    return v.decode("utf-8")

def load_tuple(fp):
    dim_tuple = struct.unpack("B", fp.read(1))[0]
    ret = []
    for _ in range(dim_tuple):
        ret.append(struct.unpack("I", fp.read(4))[0]) 
    return tuple(ret)

def load_parameter(fp):    
    shape = load_tuple(fp)
    value_size = struct.unpack("I", fp.read(4))[0]
    dtype = load_dtype(fp)
    value = fp.read(value_size)
    return shape, value, dtype

def load(fp, parent_name=''):
    num_parameters, num_sub_layers = struct.unpack("II", fp.read(8))
    parameters = []

    for _ in range(num_parameters):
        name = load_string(fp)
        shape, value, dtype = load_parameter(fp)
        parameters.append((parent_name + '.' + name, np.frombuffer(value, dtype).reshape(shape)))
    for _ in range(num_sub_layers):
        name = load_string(fp)
        parameters += load(fp, parent_name + '.' + name)
    return parameters

In [11]:
with open('../zhiyuan/cpm2.1/checkpoint.pt', 'rb') as fp:
    parameters = load(fp)

In [12]:
pindex = {x[0]: x[1] for x in parameters}

In [13]:
npara = {}
for name, value in tqdm(parameters):
    if '_scale' not in name:
        has_scale = name + '_scale'
        if has_scale in pindex:
            scale = pindex[has_scale]
            value = value.astype(np.float16) * scale
        npara[name] = value

100%|██████████| 704/704 [03:22<00:00,  3.47it/s]


In [14]:
def get_encoder(n):
    params = []
    for k, v in npara.items():
        if n == 0 and '.encoder_position_bias.embedding.weight' in k:
            params.append((
                'encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight',
                v,
            ))
        if f'.encoder.{n}.' in k:
            if 'self_attention.w_project_qkv' in k:
                params.append((
                    f'encoder.block.{n}.layer.0.SelfAttention.q.weight',
                    v[0]
                ))
                params.append((
                    f'encoder.block.{n}.layer.0.SelfAttention.k.weight',
                    v[1]
                ))
                params.append((
                    f'encoder.block.{n}.layer.0.SelfAttention.v.weight',
                    v[2]
                ))
            if 'self_attention.w_out' in k:
                params.append((
                    f'encoder.block.{n}.layer.0.SelfAttention.o.weight',
                    v,
                ))
            if 'layer_nrom_before_self_attn.weight' in k:
                params.append((
                    f'encoder.block.{n}.layer.0.layer_norm.weight',
                    v,
                ))
            if 'dense_gelu_dense.wi_0.weight' in k:
                params.append((
                    f'encoder.block.{n}.layer.1.DenseReluDense.wi_0.weight',
                    v,
                ))
            if 'dense_gelu_dense.wi_1.weight' in k:
                params.append((
                    f'encoder.block.{n}.layer.1.DenseReluDense.wi_1.weight',
                    v,
                ))
            if 'dense_gelu_dense.wo.weight' in k:
                params.append((
                    f'encoder.block.{n}.layer.1.DenseReluDense.wo.weight',
                    v,
                ))
            if 'layer_nrom_before_ff.weight' in k:
                params.append((
                    f'encoder.block.{n}.layer.1.layer_norm.weight',
                    v,
                ))
    return params

In [15]:
def get_decoder(n):
    params = []
    for k, v in npara.items():
        if n == 0 and '.decoder_position_bias.embedding.weight' in k:
            params.append((
                'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight',
                v,
            ))
        if f'.decoder.{n}.' in k:
            if 'self_attention.w_project_qkv' in k:
                params.append((
                    f'decoder.block.{n}.layer.0.SelfAttention.q.weight',
                    v[0]
                ))
                params.append((
                    f'decoder.block.{n}.layer.0.SelfAttention.k.weight',
                    v[1]
                ))
                params.append((
                    f'decoder.block.{n}.layer.0.SelfAttention.v.weight',
                    v[2]
                ))
            if 'self_attention.w_out' in k:
                params.append((
                    f'decoder.block.{n}.layer.0.SelfAttention.o.weight',
                    v,
                ))
            if 'layer_nrom_before_self_attn.weight' in k:
                params.append((
                    f'decoder.block.{n}.layer.0.layer_norm.weight',
                    v,
                ))
                
            if '.cross_attention.w_project_q' in k:
                params.append((
                    f'decoder.block.{n}.layer.1.EncDecAttention.q.weight',
                    v,
                ))
            if '.cross_attention.w_out' in k:
                params.append((
                    f'decoder.block.{n}.layer.1.EncDecAttention.o.weight',
                    v,
                ))
            if 'layer_nrom_before_cross_attn.weight' in k:
                params.append((
                    f'decoder.block.{n}.layer.1.layer_norm.weight',
                    v,
                ))

            if 'dense_gelu_dense.wi_0.weight' in k:
                params.append((
                    f'decoder.block.{n}.layer.2.DenseReluDense.wi_0.weight',
                    v,
                ))
            if 'dense_gelu_dense.wi_1.weight' in k:
                params.append((
                    f'decoder.block.{n}.layer.2.DenseReluDense.wi_1.weight',
                    v,
                ))
            if 'dense_gelu_dense.wo.weight' in k:
                params.append((
                    f'decoder.block.{n}.layer.2.DenseReluDense.wo.weight',
                    v,
                ))
            if 'layer_nrom_before_ff.weight' in k:
                params.append((
                    f'decoder.block.{n}.layer.2.layer_norm.weight',
                    v,
                ))
#     params.append((
#         f'decoder.block.{n}.layer.1.EncDecAttention.k.weight',
#         npara[f'.encoder.{n}.self_attention.w_project_qkv'][1]
#     ))
#     params.append((
#         f'decoder.block.{n}.layer.1.EncDecAttention.v.weight',
#         npara[f'.encoder.{n}.self_attention.w_project_qkv'][2]
#     ))
    params.append((
        f'decoder.block.{n}.layer.1.EncDecAttention.k.weight',
        npara['.encoder_kv.w_project_kv'][n][0]
    ))
    params.append((
        f'decoder.block.{n}.layer.1.EncDecAttention.v.weight',
        npara['.encoder_kv.w_project_kv'][n][1]
    ))
    return params

In [16]:
new_state_dict = []
new_state_dict.append((
    'shared.weight',
    npara['.input_embedding.weight'],
))
new_state_dict.append((
    'encoder.embed_tokens.weight',
    npara['.input_embedding.weight'],
))

for i in range(24):
    new_state_dict += get_encoder(i)

new_state_dict.append((
    'encoder.final_layer_norm.weight',
    npara['.encoder_final_layer_nrom.weight'],
))

new_state_dict.append((
    'decoder.embed_tokens.weight',
    npara['.input_embedding.weight'],
))

for i in range(24):
    new_state_dict += get_decoder(i)

new_state_dict.append((
    'decoder.final_layer_norm.weight',
    npara['.decoder_final_layer_nrom.weight'],
))

new_state_dict.append((
    'lm_head.weight',
    npara['.lm_head.weight'],
))

In [17]:
len(new_state_dict)

560

In [18]:
len(model.state_dict())

560

In [19]:
model.load_state_dict({
    k: torch.from_numpy(v)
    for k, v in new_state_dict
})

  k: torch.from_numpy(v)


<All keys matched successfully>

In [20]:
model = model.eval()

In [35]:
input_text = '''当地时间9月6日是美国劳工节，但就在这一天，上千万美国劳动者却陷入新的困境。因为美国政府为疫情期间失业者提供的主要救助同日到期，而且白宫表示没有进一步延长救助的计划。
在德尔塔变异株已把美国推入新一轮疫情的背景下，失业救济的突然“断供”意味着有上千万美国人将全部或部分失去他们的生活来源。'''
output_text = '''美国'''
input_ids = torch.LongTensor([tokenizer.encode(input_text) + [tokenizer.get_sentinel_id(189)]])

for i in range(20):
    decoder_input_ids = torch.LongTensor([[tokenizer.get_sentinel_id(189)]])
    out = model(input_ids, decoder_input_ids=decoder_input_ids)
    t = out['logits'][:, -1, :26050].detach().numpy()
    next_token = tokenizer.decode([np.argmax(t, -1)[0]])
    # print(next_token)
    output_text += next_token
    print(output_text)

美国，
美国，，
美国，，，
美国，，，，


KeyboardInterrupt: 