In [1]:
import re

import torch
import numpy as np
import tensorflow as tf

from tf2gpt.model import GPT

In [2]:
# Report the on-disk size of the checkpoint file loaded in the next cell.
!du -sh Pangu-alpha_2.6B_mgt/iter_0001000/mp_rank_00/model_optim_rng.pt

4.9G	Pangu-alpha_2.6B_mgt/iter_0001000/mp_rank_00/model_optim_rng.pt


In [3]:
# Load the PanGu-α 2.6B checkpoint (Megatron-style layout, judging by the path)
# onto the CPU.
# SECURITY: torch.load unpickles the file and can execute arbitrary code —
# only load checkpoints obtained from a trusted source.
m0 = torch.load('Pangu-alpha_2.6B_mgt/iter_0001000/mp_rank_00/model_optim_rng.pt', map_location='cpu')

In [4]:
m0_weights = []

def extract_weight(w=None, root='', out=None, top_layer_index=31):
    """Flatten a nested checkpoint state dict into ``(name, tensor)`` pairs.

    Each tensor's name is its dotted key path with the ``.language_model.``
    prefix removed. ``.topQueryLayer.`` is renamed to
    ``.layers.<top_layer_index>.`` so the top-query layer lines up with the
    regular transformer layers.

    Args:
        w: nested dict to walk. Defaults to ``m0['model']``, looked up at call
            time. Putting ``m0['model']`` in the signature would freeze it to
            whatever ``m0`` was when this cell first ran, and defining the
            function would fail before ``m0`` exists.
        root: dotted key prefix accumulated during recursion.
        out: list the pairs are appended to; defaults to the module-level
            ``m0_weights``.
        top_layer_index: layer index given to ``topQueryLayer`` (31 here, the
            last index of the 32-layer TF model).

    Values that are neither dicts nor tensors are reported with ``print`` and
    skipped. Appends to ``out``; returns nothing.
    """
    if w is None:
        w = m0['model']
    if out is None:
        out = m0_weights
    for k, v in w.items():
        path = root + '.' + k
        if isinstance(v, dict):
            extract_weight(v, path, out, top_layer_index)
        elif isinstance(v, torch.Tensor):
            name = path.replace('.language_model.', '')
            name = name.replace('.topQueryLayer.', f'.layers.{top_layer_index}.')
            out.append((name, v))
        else:
            print('what?', type(v))

In [5]:
# Flatten m0['model'] into m0_weights. extract_weight appends, so re-run the
# previous cell (which resets m0_weights to []) before re-running this one.
extract_weight()

In [6]:
# Number of tensors flattened out of the checkpoint. The conversion below needs
# exactly one of these per TF variable.
len(m0_weights)

517

In [7]:
# Index the flattened checkpoint by parameter name (printing each name and
# shape), so the conversion below can look tensors up by their PanGu name.
pangu_weights = {}
for k, v in m0_weights:
    print(k, v.shape)
    pangu_weights[k] = v

# extract_weight rewrites names (e.g. topQueryLayer -> layers.31). If two
# checkpoint entries ended up with the same name, the dict would silently keep
# only the last one, so fail loudly instead.
assert len(pangu_weights) == len(m0_weights), 'duplicate parameter names after renaming'

embedding.word_embeddings.weight torch.Size([40064, 2560])
embedding.position_embeddings.weight torch.Size([1024, 2560])
topQueryEmbedding.top_query_embeddings.weight torch.Size([1024, 2560])
transformer.layers.0.input_layernorm.weight torch.Size([2560])
transformer.layers.0.input_layernorm.bias torch.Size([2560])
transformer.layers.0.attention.query.weight torch.Size([2560, 2560])
transformer.layers.0.attention.query.bias torch.Size([2560])
transformer.layers.0.attention.key.weight torch.Size([2560, 2560])
transformer.layers.0.attention.key.bias torch.Size([2560])
transformer.layers.0.attention.value.weight torch.Size([2560, 2560])
transformer.layers.0.attention.value.bias torch.Size([2560])
transformer.layers.0.attention.dense.weight torch.Size([2560, 2560])
transformer.layers.0.attention.dense.bias torch.Size([2560])
transformer.layers.0.post_attention_layernorm.weight torch.Size([2560])
transformer.layers.0.post_attention_layernorm.bias torch.Size([2560])
transformer.layers.0.mlp.d

In [8]:
# Build the TF2 GPT with the dimensions seen in the checkpoint:
# 40064 x 2560 word embeddings and 1024 position embeddings. All dropout rates
# are 0 because the model is only used for inference in this notebook.
gpt = GPT(
    vocab_size=40_064,
    # extract_weight renames the checkpoint's topQueryLayer to layers.31,
    # making it the last of these 32 layers.
    layer_size=32,
    block_size=1024,
    embedding_dropout=0.0,
    embedding_size=2560,
    # 2560 / 32 = 80 per head, the same embedding_size // attention_head used
    # for the KV-cache shapes further down.
    num_attention_heads=32,
    attention_dropout=0.0,
    residual_dropout=0.0)

In [9]:
# Smoke test: one forward pass on a single token. The printed shape is
# (batch, sequence, vocab). NOTE(review): this presumably also builds the
# Keras variables that the following cells list and assign — confirm.
print(gpt(tf.constant([[1]])).shape)

(1, 1, 40064)


In [10]:
# List every non-block weight plus the weights of block 0 only; the other
# blocks presumably repeat layer00's layout.
for var in gpt.weights:
    in_block = 'gpt/layer' in var.name
    if in_block and 'gpt/layer00' not in var.name:
        continue
    print(var.name, var.shape)

gpt/embedding/embeddings:0 (40064, 2560)
position_embeddings:0 (1024, 2560)
top_query:0 (1024, 2560)
gpt/layer00/attention/query_layer/kernel:0 (2560, 2560)
gpt/layer00/attention/query_layer/bias:0 (2560,)
gpt/layer00/attention/key_layer/kernel:0 (2560, 2560)
gpt/layer00/attention/key_layer/bias:0 (2560,)
gpt/layer00/attention/value_layer/kernel:0 (2560, 2560)
gpt/layer00/attention/value_layer/bias:0 (2560,)
gpt/layer00/attention/context_projection_layer/kernel:0 (2560, 2560)
gpt/layer00/attention/context_projection_layer/bias:0 (2560,)
gpt/layer00/LayerNorm_mlp_ln0/gamma:0 (2560,)
gpt/layer00/LayerNorm_mlp_ln0/beta:0 (2560,)
gpt/layer00/LayerNorm_mlp_ln1/gamma:0 (2560,)
gpt/layer00/LayerNorm_mlp_ln1/beta:0 (2560,)
gpt/layer00/intermediate/kernel:0 (2560, 10240)
gpt/layer00/intermediate/bias:0 (10240,)
gpt/layer00/output/kernel:0 (10240, 2560)
gpt/layer00/output/bias:0 (2560,)
gpt/LayerNorm_final_norm/gamma:0 (2560,)
gpt/LayerNorm_final_norm/beta:0 (2560,)


In [20]:
# Map every TF variable of `gpt` to the PanGu (PyTorch) tensor that initializes
# it. Each rule is (substring of the TF variable name, PanGu name, transpose?).
# Kernels are transposed because the checkpoint stores linear weights in the
# opposite orientation to the Keras kernels; every mapping is shape-checked.
# Biases and LayerNorm parameters copy over unchanged.

# Checked first, before the per-layer rules.
_PANGU_EMBEDDING_RULES = [
    ('gpt/embedding/embeddings:', 'embedding.word_embeddings.weight', False),
    ('position_embeddings', 'embedding.position_embeddings.weight', False),
    ('top_query', 'topQueryEmbedding.top_query_embeddings.weight', False),
]

# Per-block rules; the PanGu name is prefixed with 'transformer.layers.<n>.'.
_PANGU_LAYER_RULES = [
    ('query_layer/kernel', 'attention.query.weight', True),
    ('key_layer/kernel', 'attention.key.weight', True),
    ('value_layer/kernel', 'attention.value.weight', True),
    ('query_layer/bias', 'attention.query.bias', False),
    ('key_layer/bias', 'attention.key.bias', False),
    ('value_layer/bias', 'attention.value.bias', False),
    ('attention/context_projection_layer/kernel', 'attention.dense.weight', True),
    ('attention/context_projection_layer/bias', 'attention.dense.bias', False),
    ('LayerNorm_mlp_ln0/gamma', 'input_layernorm.weight', False),
    ('LayerNorm_mlp_ln1/gamma', 'post_attention_layernorm.weight', False),
    ('LayerNorm_mlp_ln0/beta', 'input_layernorm.bias', False),
    ('LayerNorm_mlp_ln1/beta', 'post_attention_layernorm.bias', False),
    ('intermediate/kernel', 'mlp.dense_h_to_4h.weight', True),
    ('intermediate/bias', 'mlp.dense_h_to_4h.bias', False),
    ('/output/kernel', 'mlp.dense_4h_to_h.weight', True),
    ('/output/bias', 'mlp.dense_4h_to_h.bias', False),
]

# Checked for names outside the per-layer scopes.
_PANGU_FINAL_RULES = [
    ('gpt/LayerNorm_final_norm/gamma', 'transformer.final_layernorm.weight', False),
    ('gpt/LayerNorm_final_norm/beta', 'transformer.final_layernorm.bias', False),
]


def _pangu_source(tf_name):
    """Return ``(pangu_name, transpose)`` for the TF variable named ``tf_name``.

    Raises:
        KeyError: if no rule matches. An unmapped variable then stops the cell
            instead of silently producing a shortened weight list.
    """
    for needle, pname, transpose in _PANGU_EMBEDDING_RULES:
        if needle in tf_name:
            return pname, transpose
    if 'gpt/layer' in tf_name:
        # Parse the block index with a regex rather than slicing exactly two
        # characters, so three-digit indices also resolve.
        match = re.search(r'gpt/layer(\d+)/', tf_name)
        if match is not None:
            n_layer = int(match.group(1))
            for needle, suffix, transpose in _PANGU_LAYER_RULES:
                if needle in tf_name:
                    return f'transformer.layers.{n_layer}.{suffix}', transpose
    else:
        for needle, pname, transpose in _PANGU_FINAL_RULES:
            if needle in tf_name:
                return pname, transpose
    raise KeyError(f'no PanGu tensor mapped for TF variable {tf_name!r}')


new_weights = []

for x in gpt.weights:
    xs = tuple(x.shape)
    pname, transpose = _pangu_source(x.name)
    w = pangu_weights[pname]
    if transpose:
        w = np.transpose(w)
    # Check every weight against the TF variable's own shape, not a literal.
    assert xs == tuple(w.shape), (x.name, xs, pname, tuple(w.shape))
    new_weights.append((x.name, xs, pname, w))

In [21]:
# Sanity check: one converted tensor per TF variable, each with the TF shape.
assert len(new_weights) == len(gpt.weights)
for _tf_name, tf_shape, _pangu_name, tensor in new_weights:
    assert tuple(tf_shape) == tensor.shape

In [22]:
# Total TF variables; should equal len(m0_weights).
len(gpt.weights)

517

In [23]:
# Assign the converted tensors. new_weights was built by iterating
# gpt.weights, so it is already in the order set_weights expects.
gpt.set_weights([x[-1] for x in new_weights])

In [15]:
# PanGu's jieba-based BPE tokenizer, loaded from the vocab files in the
# PanGu-Alpha-GPU checkout. NOTE(review): this import lives here rather than
# in the top imports cell; consider moving it up.
from tokenization_jieba import JIEBATokenizer
cbpe = JIEBATokenizer(
    'PanGu-Alpha-GPU/panguAlpha_pytorch/megatron/tokenizer/bpe_4w_pcl/vocab.vocab',
    'PanGu-Alpha-GPU/panguAlpha_pytorch/megatron/tokenizer/bpe_4w_pcl/vocab.model')

In [16]:
# The tokenizer vocabulary is smaller than the model's 40064-row embedding.
# NOTE(review): the embedding is presumably padded (40064 = 626 * 64). Sampling
# can still return ids >= vocab_size, and decode may not handle them — confirm.
cbpe.vocab_size

40000

In [24]:
# Greedy decoding without the KV cache. Each step re-runs the whole sequence,
# appends the arg-max next token, and prints the decoded text so far.
ids = cbpe.encode('青椒肉丝的做法：')

for step in range(10):
    next_token_logits = gpt(tf.constant([ids]))[0, -1]
    ids += [int(np.argmax(next_token_logits))]
    print(step, cbpe.decode(ids))

0 青椒肉丝的做法:是
1 青椒肉丝的做法:是青
2 青椒肉丝的做法:是青椒
3 青椒肉丝的做法:是青椒洗净
4 青椒肉丝的做法:是青椒洗净切
5 青椒肉丝的做法:是青椒洗净切丝
6 青椒肉丝的做法:是青椒洗净切丝<eot>
7 青椒肉丝的做法:是青椒洗净切丝<eot>青
8 青椒肉丝的做法:是青椒洗净切丝<eot>青椒
9 青椒肉丝的做法:是青椒洗净切丝<eot>青椒肉


In [25]:
@tf.function
def batch_gather(a, b):
    """Gather from ``a`` along axis 1, using a separate index row ``b[i]`` for each row ``a[i]``."""
    return tf.gather(params=a, indices=b, batch_dims=1)


@tf.function
def top_k_top_p_sample(logits, num_samples=1, top_k=0, p=0.95):
    """Sample token ids from ``logits`` with optional top-k and nucleus filtering.

    Args:
        logits: float32 tensor of shape [batch, vocab].
        num_samples: number of ids drawn per row.
        top_k: keep only the ``top_k`` most probable tokens; <= 0 disables it.
        p: keep only tokens whose inclusive cumulative probability (in
            descending order) is still below ``p``; <= 0 disables it. The most
            probable token is always kept.

    Returns:
        int64 tensor [batch, num_samples] of ids in the original vocab order.
    """
    probs = tf.nn.softmax(logits, axis=-1)

    # Work in sorted space: column j of each row holds its j-th most likely token.
    indices = tf.argsort(probs, direction='DESCENDING')  # [batch, vocab]
    logits_to_use = batch_gather(logits, indices)
    cumulative_probabilities = tf.math.cumsum(batch_gather(probs, indices), axis=-1, exclusive=False)

    # Rank of each sorted column, shape [1, vocab]; broadcasts over the batch.
    # Taken from the dynamic shape so an unknown static vocab size also works.
    rank = tf.range(tf.shape(logits)[-1])[None]

    if p > 0.0:
        # Exclude columns once the cumulative mass reaches p, but never rank 0,
        # so at least one token always remains.
        exclude_mask = tf.logical_not(
            tf.logical_or(cumulative_probabilities < p, rank < 1))
        logits_to_use = logits_to_use - tf.cast(exclude_mask, tf.float32) * 1e10

    if top_k > 0:
        # Columns are already sorted, so a column's rank is its position. A
        # second argsort would return permutation indices rather than ranks and
        # cost an extra full-vocabulary sort per step.
        logits_to_use = logits_to_use - tf.cast(rank >= top_k, dtype=tf.float32) * 1e10

    sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
    # Map the sampled sorted-space columns back to vocabulary ids.
    sample = batch_gather(indices, sample_perm)

    return tf.cast(sample, tf.int64)

@tf.function
def serve(inputs):
    """Prompt pass: run ``inputs`` with no existing cache; returns ``(logits, kv_cache)``."""
    return gpt(inputs, kv_cache=None, use_cache=True)


@tf.function
def serve_cache(inputs, kv_cache):
    """Incremental pass: run new token(s) ``inputs`` against the cache built so far."""
    return gpt(inputs, kv_cache=kv_cache, use_cache=True)

# Trace the serving graphs with dynamic batch and sequence dimensions.
# NOTE(review): these concrete functions are not referenced by the other
# visible cells; sample() calls serve / serve_cache directly.
serve_concrete = serve.get_concrete_function(
    tf.TensorSpec(shape=[None, None], dtype=tf.int64, name="inp")
)

# Model dimensions used to spell out the KV-cache TensorSpec below.
layer_size = 32
attention_head = 32
embedding_size = 2560

# KV-cache layout is presumably
# [layers, batch, 2 (key/value), heads, seq, head_dim] — TODO confirm
# against tf2gpt.model.
serve_cache_concrete = serve_cache.get_concrete_function(
    tf.TensorSpec(shape=[None, None], dtype=tf.int64, name="inp"),
    tf.TensorSpec(shape=[
        layer_size, None, 2, attention_head,
        None, embedding_size // attention_head
    ], dtype=tf.float32, name="kv_cache")
)

@tf.function
def sample(initial_inputs, length, top_k, top_p, temperature):
    """Autoregressively sample a continuation of ``initial_inputs`` using the KV cache.

    Args:
        initial_inputs: int64 token ids of shape [batch, prompt_len].
        length: number of loop iterations. One token is sampled from the prompt
            pass before the loop and one per iteration, so ``length + 1`` tokens
            are appended in total.
        top_k: forwarded to ``top_k_top_p_sample`` (<= 0 disables top-k).
        top_p: forwarded to ``top_k_top_p_sample`` (<= 0 disables nucleus).
        temperature: logits are divided by it when > 0; otherwise left as-is.

    Returns:
        int64 tensor [batch, prompt_len + length + 1]: the prompt followed by
        the sampled tokens.
    """
    i = tf.constant(0, dtype=tf.int64)
    initial_logits, kv_cache = serve(initial_inputs)
    logits_with_temperature = initial_logits[:, -1, :]
    if temperature > 0.0:
        logits_with_temperature /= temperature
    inputs = top_k_top_p_sample(logits_with_temperature, 1, top_k, top_p)
    stores = tf.concat([initial_inputs, inputs], axis=1)

    # The cache grows along its sequence axis (-2) every step, so that axis and
    # the batch axis stay unconstrained. The remaining axes are taken from the
    # cache the model actually returned rather than hard-coding the model size.
    cache_shape = kv_cache.shape
    kv_cache_invariant = tf.TensorShape([
        cache_shape[0], None, cache_shape[2],
        cache_shape[3], None, cache_shape[5]
    ])

    def _cond(i, inputs, kv_cache, stores):
        return i < length

    def _body(i, inputs, kv_cache, stores):
        new_logits, new_kv_cache = serve_cache(inputs, kv_cache)
        logits_with_temperature = new_logits[:, -1, :]
        if temperature > 0.0:
            logits_with_temperature /= temperature
        new_inputs = top_k_top_p_sample(logits_with_temperature, 1, top_k, top_p)
        new_stores = tf.concat([stores, new_inputs], axis=1)
        # serve_cache presumably returns only the entries for the token(s) just
        # fed in; append them to the running cache along the sequence axis.
        new_kv_cache = tf.concat([kv_cache, new_kv_cache], axis=-2)
        return [i + 1, new_inputs, new_kv_cache, new_stores]

    result = tf.while_loop(
        _cond, _body,
        loop_vars=[i, inputs, kv_cache, stores],
        shape_invariants=[
            tf.TensorShape(None),
            tf.TensorShape([None, None]),
            kv_cache_invariant,
            tf.TensorShape([None, None]),
        ]
    )
    return result[-1]

In [26]:
# Sample 15 loop steps (16 new tokens) from a prompt with top_k=15,
# top_p=0.95 and temperature=0.9. No random seed is set, so the continuation
# differs between runs.
ids = cbpe.encode('今天天气不错')

ret = sample(
    tf.constant([ids], dtype=tf.int64),
    tf.constant(15, dtype=tf.int64),
    tf.constant(15, dtype=tf.int32),
    tf.constant(0.95, dtype=tf.float32),
    tf.constant(0.9, dtype=tf.float32)
)
print(ret)
print(cbpe.decode(ret.numpy().tolist()[0]))

tf.Tensor(
[[  465   235   464  1123    10    21    18 32636  3001 21507    13     3
   3001  3001 32504    10  2448    24    58   201]], shape=(1, 20), dtype=int64)
今天天气不错,我在树下吹吹风
吹吹凉风,心情也会跟


In [27]:
# Export the model as a SavedModel with sample() as its serving signature.
# NOTE(review): length/top_k/top_p/temperature are declared with shape [None]
# (vectors), but sample() uses them as scalars in its `if` / loop conditions,
# and the notebook call above passes scalars. Confirm the exported signature
# behaves as intended for callers (e.g. single-element inputs), or consider
# shape=[].
gpt.save('./pangu-2.6B-tf2', include_optimizer=False, signatures={
    'serving_default': sample.get_concrete_function(
        tf.TensorSpec(shape=[None, None], dtype=tf.int64, name="inp"),
        tf.TensorSpec(shape=[None,], dtype=tf.int64, name="length"),
        tf.TensorSpec(shape=[None,], dtype=tf.int32, name="top_k"),
        tf.TensorSpec(shape=[None,], dtype=tf.float32, name="top_p"),
        tf.TensorSpec(shape=[None,], dtype=tf.float32, name="temperature")
    )
})





INFO:tensorflow:Assets written to: ./pangu-2.6B-tf2/assets


INFO:tensorflow:Assets written to: ./pangu-2.6B-tf2/assets
