In [1]:
import re

import torch
import numpy as np
import tensorflow as tf

from tf2gpt.model import GPT

In [2]:
!ls 'CPM-distill/310000'

mp_rank_00_model_states.pt  mp_rank_01_model_states.pt


In [3]:
# Load the two `mp_rank_*` checkpoint files onto the CPU; the conversion cell
# further down stitches tensors from both of them together.
# NOTE: torch.load unpickles the file contents, so only load checkpoints that
# come from a trusted source.
m0 = torch.load('./CPM-distill/310000/mp_rank_00_model_states.pt', map_location='cpu')
m1 = torch.load('./CPM-distill/310000/mp_rank_01_model_states.pt', map_location='cpu')

In [4]:
def find_weight(model, name):
    """Look up one parameter of a loaded checkpoint by its exact name.

    Args:
        model: Checkpoint object whose ``'module'`` entry maps parameter
            names to tensors.
        name: Exact parameter name, e.g. ``'word_embeddings.weight'``.

    Returns:
        A ``(tensor, shape)`` tuple, where ``shape`` is ``list(tensor.shape)``.

    Raises:
        KeyError: If ``name`` is not present in ``model['module']``.
    """
    params = model['module']
    # Direct mapping lookup instead of scanning every entry.  A missing name
    # used to fall off the end of the loop and return None, which then failed
    # at the caller with an unrelated "cannot unpack NoneType" error.
    if name not in params:
        raise KeyError(f'parameter {name!r} not found in checkpoint')
    w = params[name]
    return w, list(w.shape)

In [5]:
# Print name and shape of every parameter in shard 0 that lives outside the
# transformer layers, and of the layer-0 parameters only.
for n, w in m0['module'].items():
    if '.layers.' not in n or '.layers.0.' in n:
        print(n, w.shape)

word_embeddings.weight torch.Size([15000, 768])
position_embeddings.weight torch.Size([1024, 768])
transformer.layers.0.input_layernorm.weight torch.Size([768])
transformer.layers.0.input_layernorm.bias torch.Size([768])
transformer.layers.0.attention.query_key_value.weight torch.Size([1152, 768])
transformer.layers.0.attention.query_key_value.bias torch.Size([1152])
transformer.layers.0.attention.dense.weight torch.Size([768, 384])
transformer.layers.0.attention.dense.bias torch.Size([768])
transformer.layers.0.post_attention_layernorm.weight torch.Size([768])
transformer.layers.0.post_attention_layernorm.bias torch.Size([768])
transformer.layers.0.mlp.dense_h_to_4h.weight torch.Size([1536, 768])
transformer.layers.0.mlp.dense_h_to_4h.bias torch.Size([1536])
transformer.layers.0.mlp.dense_4h_to_h.weight torch.Size([768, 1536])
transformer.layers.0.mlp.dense_4h_to_h.bias torch.Size([768])
transformer.final_layernorm.weight torch.Size([768])
transformer.final_layernorm.bias torch.Size([

In [6]:
# Shape of the word-embedding table stored in shard 0.
find_weight(m0, 'word_embeddings.weight')[0].shape

torch.Size([15000, 768])

In [7]:
# Shape of the position-embedding table stored in shard 0.
find_weight(m0, 'position_embeddings.weight')[0].shape

torch.Size([1024, 768])

In [8]:
# list(m0['module'].keys())

In [9]:
# Build the TF2 model that receives the converted checkpoint weights.
# vocab_size is 30,000 because the conversion cell below concatenates the two
# 15,000-row word-embedding shards into a single table.
gpt = GPT(
    vocab_size=30_000,
    layer_size=12,
    block_size=1024,  # position embedding
    embedding_dropout=0.0,
    embedding_size=768,
    num_attention_heads=12,
    attention_dropout=0.0,
    residual_dropout=0.0)

In [10]:
# Run one forward pass on a single token id and show the output shape.
# NOTE(review): presumably this first call is also what creates the model's
# variables before `gpt.weights` is inspected below — confirm.
print(gpt(tf.constant([[1]])).shape)

(1, 1, 30000)


In [11]:
# for x in gpt.weights:
#     assert x.dtype == tf.float16

In [12]:
# Print name and shape of every TF variable outside the per-layer blocks, and
# of the layer-00 variables only.
for x in gpt.weights:
    if 'gpt/layer' not in x.name or 'gpt/layer00' in x.name:
        print(x.name, x.shape)

gpt/embedding/embeddings:0 (30000, 768)
position_embeddings:0 (1024, 768)
gpt/layer00/attention/query_layer/kernel:0 (768, 768)
gpt/layer00/attention/query_layer/bias:0 (768,)
gpt/layer00/attention/key_layer/kernel:0 (768, 768)
gpt/layer00/attention/key_layer/bias:0 (768,)
gpt/layer00/attention/value_layer/kernel:0 (768, 768)
gpt/layer00/attention/value_layer/bias:0 (768,)
gpt/layer00/attention/context_projection_layer/kernel:0 (768, 768)
gpt/layer00/attention/context_projection_layer/bias:0 (768,)
gpt/layer00/LayerNorm_mlp_ln0/gamma:0 (768,)
gpt/layer00/LayerNorm_mlp_ln0/beta:0 (768,)
gpt/layer00/LayerNorm_mlp_ln1/gamma:0 (768,)
gpt/layer00/LayerNorm_mlp_ln1/beta:0 (768,)
gpt/layer00/intermediate/kernel:0 (768, 3072)
gpt/layer00/intermediate/bias:0 (3072,)
gpt/layer00/output/kernel:0 (3072, 768)
gpt/layer00/output/bias:0 (768,)
gpt/LayerNorm_final_norm/gamma:0 (768,)
gpt/LayerNorm_final_norm/beta:0 (768,)


In [13]:
# Shape of layer 0's `dense_h_to_4h` weight as stored in shard 0.
n_layer=0
find_weight(m0, f'transformer.layers.{n_layer}.mlp.dense_h_to_4h.weight')[0].shape

torch.Size([1536, 768])

In [16]:
# Convert the two checkpoint shards (m0, m1) into NumPy arrays laid out for
# the TF model.  `new_weights` receives one
# (tf_name, tf_shape, checkpoint_name, array) tuple per entry of
# `gpt.weights`, in the same order, so the arrays can later be handed to
# `gpt.set_weights` as-is.
#
# How the shards are combined (see the shape assertions in each branch):
#   * word embeddings, the fused query_key_value tensors and dense_h_to_4h
#     are concatenated along axis 0;
#   * attention.dense and dense_4h_to_h weights are concatenated along axis 1;
#   * position embeddings, layer norms and the attention.dense /
#     dense_4h_to_h biases are read from shard 0 only.
# Linear weights are transposed because the checkpoint stores them transposed
# relative to the TF kernels.
new_weights = []

for x in gpt.weights:
    xs = list(x.shape)

    if 'gpt/embedding/embeddings:' in x.name:
        pname = 'word_embeddings.weight'
        w0, ws0 = find_weight(m0, pname)
        w1, ws1 = find_weight(m1, pname)
        assert ws0 == [15_000, 768]
        assert ws1 == [15_000, 768]
        w = np.concatenate([w0.numpy(), w1.numpy()])
        assert w.shape == (30_000, 768)
        new_weights.append((x.name, xs, pname, w))

    elif 'position_embeddings' in x.name:
        pname = 'position_embeddings.weight'
        w0, ws0 = find_weight(m0, pname)
        assert xs == ws0
        w = w0.numpy()
        new_weights.append((x.name, xs, pname, w))

    elif 'gpt/layer' in x.name:
        # Variable names look like 'gpt/layerNN/...'; NN is the layer index.
        n_layer = int(x.name[len('gpt/layer'):][:2])
        if 'query_layer' in x.name or 'key_layer' in x.name or 'value_layer' in x.name:
            # Each shard holds a fused 1152-row tensor: rows 0-383 feed the
            # query layer, 384-767 the key layer and 768-1151 the value layer.
            if 'query_layer' in x.name:
                rows = slice(0, 384)
            elif 'key_layer' in x.name:
                rows = slice(384, 384 * 2)
            else:  # value_layer
                rows = slice(384 * 2, None)

            if '/kernel' in x.name:
                pname = f'transformer.layers.{n_layer}.attention.query_key_value.weight'
                w0, ws0 = find_weight(m0, pname)
                w1, ws1 = find_weight(m1, pname)
                assert ws0 == [1152, 768]
                assert ws1 == [1152, 768]
                # Take this projection's rows from both shards, then transpose.
                w = np.concatenate([w0.numpy()[rows, :], w1.numpy()[rows, :]])
                w = np.transpose(w)
                assert w.shape == (768, 768)
                new_weights.append((x.name, xs, pname, w))

            elif '/bias' in x.name:
                pname = f'transformer.layers.{n_layer}.attention.query_key_value.bias'
                w0, ws0 = find_weight(m0, pname)
                w1, ws1 = find_weight(m1, pname)
                assert ws0 == [1152,]
                assert ws1 == [1152,]
                w = np.concatenate([w0.numpy()[rows], w1.numpy()[rows]])
                assert w.shape == (768,)
                new_weights.append((x.name, xs, pname, w))

            else:
                raise ValueError(
                    f'no checkpoint mapping for TF variable {x.name} with shape {xs}')

        elif 'attention/context_projection_layer/kernel' in x.name:
            pname = f'transformer.layers.{n_layer}.attention.dense.weight'
            w0, ws0 = find_weight(m0, pname)
            w1, ws1 = find_weight(m1, pname)
            w = np.concatenate([w0.numpy(), w1.numpy()], axis=1)
            w = np.transpose(w)
            assert w.shape == (768, 768)
            new_weights.append((x.name, xs, pname, w))

        elif 'attention/context_projection_layer/bias' in x.name:
            pname = f'transformer.layers.{n_layer}.attention.dense.bias'
            w0, ws0 = find_weight(m0, pname)
            w = w0.numpy()
            assert w.shape == (768,)
            new_weights.append((x.name, xs, pname, w))

        elif 'LayerNorm_mlp_ln0/gamma' in x.name:
            pname = f'transformer.layers.{n_layer}.input_layernorm.weight'
            w0, ws0 = find_weight(m0, pname)
            w = w0.numpy()
            assert ws0 == xs
            # Store the plain list `xs`, like every other branch.
            new_weights.append((x.name, xs, pname, w))

        elif 'LayerNorm_mlp_ln1/gamma' in x.name:
            pname = f'transformer.layers.{n_layer}.post_attention_layernorm.weight'
            w0, ws0 = find_weight(m0, pname)
            w = w0.numpy()
            assert ws0 == xs
            new_weights.append((x.name, xs, pname, w))

        elif 'LayerNorm_mlp_ln0/beta' in x.name:
            pname = f'transformer.layers.{n_layer}.input_layernorm.bias'
            w0, ws0 = find_weight(m0, pname)
            w = w0.numpy()
            assert ws0 == xs
            new_weights.append((x.name, xs, pname, w))

        elif 'LayerNorm_mlp_ln1/beta' in x.name:
            pname = f'transformer.layers.{n_layer}.post_attention_layernorm.bias'
            w0, ws0 = find_weight(m0, pname)
            w = w0.numpy()
            assert ws0 == xs
            new_weights.append((x.name, xs, pname, w))

        elif 'intermediate/kernel' in x.name:
            pname = f'transformer.layers.{n_layer}.mlp.dense_h_to_4h.weight'
            w0, ws0 = find_weight(m0, pname)
            w1, ws1 = find_weight(m1, pname)
            w = np.concatenate([w0.numpy(), w1.numpy()], axis=0)
            w = np.transpose(w)
            assert w.shape == (768, 3072)
            new_weights.append((x.name, xs, pname, w))

        elif 'intermediate/bias' in x.name:
            pname = f'transformer.layers.{n_layer}.mlp.dense_h_to_4h.bias'
            w0, ws0 = find_weight(m0, pname)
            w1, ws1 = find_weight(m1, pname)
            # 1-D bias: concatenation alone yields the final layout.
            w = np.concatenate([w0.numpy(), w1.numpy()], axis=0)
            assert w.shape == (3072,)
            new_weights.append((x.name, xs, pname, w))

        elif '/output/kernel' in x.name:
            pname = f'transformer.layers.{n_layer}.mlp.dense_4h_to_h.weight'
            w0, ws0 = find_weight(m0, pname)
            w1, ws1 = find_weight(m1, pname)
            w = np.concatenate([w0.numpy(), w1.numpy()], axis=1)
            w = np.transpose(w)
            assert w.shape == (3072, 768)
            new_weights.append((x.name, xs, pname, w))

        elif '/output/bias' in x.name:
            pname = f'transformer.layers.{n_layer}.mlp.dense_4h_to_h.bias'
            w0, ws0 = find_weight(m0, pname)
            w = w0.numpy()
            assert w.shape == (768,)
            new_weights.append((x.name, xs, pname, w))

        else:
            # Fail loudly instead of printing and leaving `new_weights`
            # silently truncated.
            raise ValueError(
                f'no checkpoint mapping for TF variable {x.name} with shape {xs}')

    elif 'gpt/LayerNorm_final_norm/gamma' in x.name:
        pname = 'transformer.final_layernorm.weight'
        w0, ws0 = find_weight(m0, pname)
        w = w0.numpy()
        assert ws0 == xs
        new_weights.append((x.name, xs, pname, w))

    elif 'gpt/LayerNorm_final_norm/beta' in x.name:
        pname = 'transformer.final_layernorm.bias'
        w0, ws0 = find_weight(m0, pname)
        w = w0.numpy()
        assert ws0 == xs
        new_weights.append((x.name, xs, pname, w))

    else:
        raise ValueError(
            f'no checkpoint mapping for TF variable {x.name} with shape {xs}')

In [17]:
# Every TF variable must have received exactly one converted array.
assert len(new_weights) == len(gpt.weights)

In [18]:
# Each converted array (last tuple element) must have the shape recorded for
# its target TF variable (second tuple element).
for x in new_weights:
    assert tuple(x[1]) == x[-1].shape

In [19]:
# Load the converted arrays into the model; `new_weights` was built by
# iterating `gpt.weights`, so the order already matches.
gpt.set_weights([x[-1] for x in new_weights])

In [20]:
from gpt2_tokenizer import GPT2Tokenizer

In [21]:
# Tokenizer built from the vocabulary, merges and model files under
# CPM-Generate/bpe_3w_new.
cbpe = GPT2Tokenizer(
    'CPM-Generate/bpe_3w_new/vocab.json',
    'CPM-Generate/bpe_3w_new/merges.txt',
    model_file='CPM-Generate/bpe_3w_new/chinese_vocab.model')

In [22]:
# Greedy decoding sanity check for the converted weights: ten times in a row,
# run the model on the whole sequence so far, take the arg-max token at the
# last position, append it to `ids` and print the decoded text.
ids = cbpe.encode('今天天气不错')

for i in range(10):
    output = gpt(tf.constant([ids]))
    nid = np.argmax(output[0, -1])
    ids += [nid]
    print(i, cbpe.decode(ids))

Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.524 seconds.
Prefix dict has been built successfully.


0 今天天气 不错 
1 今天天气 不错 ,
2 今天天气 不错 , 我
3 今天天气 不错 , 我 就
4 今天天气 不错 , 我 就 去
5 今天天气 不错 , 我 就 去 了
6 今天天气 不错 , 我 就 去 了 
7 今天天气 不错 , 我 就 去 了 。
8 今天天气 不错 , 我 就 去 了 。 
9 今天天气 不错 , 我 就 去 了 。  


In [243]:
@tf.function
def batch_gather(a, b):
    """Gather from `a` with the indices in `b`, treating axis 0 as a batch axis.

    Thin wrapper around ``tf.gather`` with ``batch_dims=1``.
    """
    return tf.gather(params=a, indices=b, batch_dims=1)


@tf.function
def top_k_top_p_sample(logits, num_samples=1, top_k=0, p=0.95):
    """Sample ids from `logits` with optional top-p and top-k filtering.

    Args:
        logits: Float tensor of shape [batch_size, vocab_size]; the vocabulary
            dimension must be statically known.
        num_samples: Number of ids drawn per batch row.
        top_k: If > 0, only the `top_k` highest-scoring entries stay eligible.
        p: If > 0, entries are walked in descending-probability order and stay
            eligible while the inclusive cumulative probability is < p; the
            most probable entry is always kept.

    Returns:
        int64 tensor of shape [batch_size, num_samples] with ids expressed in
        the original (unsorted) vocabulary order.
    """
    vocab_size = logits.shape[-1]
    probs = tf.nn.softmax(logits, axis=-1)

    # [batch_size, vocab_perm]: vocabulary ids ordered by descending probability.
    indices = tf.argsort(probs, direction='DESCENDING')
    logits_to_use = batch_gather(logits, indices)
    cumulative_probabilities = tf.math.cumsum(batch_gather(probs, indices), axis=-1, exclusive=False)

    # Rank of every column in the sorted space (0 = most probable), [1, vocab_perm].
    ranks = tf.range(vocab_size)[None]

    # find the top pth index to cut off. careful we don't want to cutoff everything!
    # result will be [batch_size, vocab_perm]
    if p > 0.0:
        exclude_mask = tf.logical_not(
            tf.logical_or(cumulative_probabilities < p, ranks < 1))
        # Sample in the sorted space, then unsort: excluded entries get a
        # very large negative offset so they are effectively never drawn.
        logits_to_use = logits_to_use - tf.cast(exclude_mask, tf.float32) * 1e10

    if top_k > 0:
        # The columns are already in descending order, so a column's position
        # is its rank; no second sort over the vocabulary is needed.
        logits_to_use = logits_to_use - tf.cast(ranks >= top_k, dtype=tf.float32) * 1e10

    sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
    # Map the sampled sorted-space positions back to vocabulary ids.
    sample = batch_gather(indices, sample_perm)

    return tf.cast(sample, tf.int64)

In [244]:
@tf.function
def serve(inputs):
    """Run the model on `inputs` with caching enabled and no prior cache.

    Returns whatever ``gpt`` returns with ``use_cache=True``; the callers in
    this notebook unpack it as ``(logits, kv_cache)``.
    """
    return gpt(inputs, use_cache=True, kv_cache=None)


@tf.function
def serve_cache(inputs, kv_cache):
    """Run the model on `inputs` with caching enabled, reusing `kv_cache`.

    Returns whatever ``gpt`` returns with ``use_cache=True``; the callers in
    this notebook unpack it as ``(logits, kv_cache)``.
    """
    return gpt(inputs, use_cache=True, kv_cache=kv_cache)

In [245]:
# Trace `serve` for int64 token-id inputs with two unknown dimensions.
serve_concrete = serve.get_concrete_function(
    tf.TensorSpec(shape=[None, None], dtype=tf.int64, name="inp")
)

# Model dimensions used to spell out the static parts of the cache shape.
layer_size = 12
attention_head = 12
embedding_size = 768

# Trace `serve_cache` for the same inputs plus a float32 cache of shape
# [layer_size, ?, 2, attention_head, ?, embedding_size // attention_head].
# NOTE(review): the two unknown dims are presumably batch and cached sequence
# length, and the fixed 2 presumably separates keys from values — confirm
# against tf2gpt.
serve_cache_concrete = serve_cache.get_concrete_function(
    tf.TensorSpec(shape=[None, None], dtype=tf.int64, name="inp"),
    tf.TensorSpec(shape=[
        layer_size, None, 2, attention_head,
        None, embedding_size // attention_head
    ], dtype=tf.float32, name="kv_cache")
)

In [246]:
# Smoke test: run the cache-less traced function on a single token id.
r = serve_concrete(
    tf.constant([[1]], tf.int64)
)

In [247]:
# Shapes of the two returned tensors (logits first, cache second).
print(r[0].shape, r[1].shape)

(1, 1, 30000) (12, 1, 2, 12, 1, 64)


In [248]:
# Smoke test: feed the cache returned by the previous call back in together
# with one more token id.
r2 = serve_cache_concrete(
    tf.constant([[1]], tf.int64),
    r[1]
)

In [249]:
# Shapes of the two tensors returned by the cached call.
print(r2[0].shape, r2[1].shape)

(1, 1, 30000) (12, 1, 2, 12, 1, 64)


In [250]:
@tf.function
def sample(initial_inputs, length, top_k, top_p, temperature):
    """Extend `initial_inputs` autoregressively by sampling from the model.

    One token is sampled from the prompt before the loop starts and the loop
    body then runs `length` times, so every row ends up with `length + 1`
    sampled tokens after the prompt.

    Args:
        initial_inputs: int64 prompt token ids, [batch, prompt_len].
        length: Number of loop iterations (compared against an int64 counter).
        top_k: Forwarded to `top_k_top_p_sample` as its `top_k`.
        top_p: Forwarded to `top_k_top_p_sample` as its `p`.
        temperature: When > 0 the last-position logits are divided by it;
            otherwise they are used unscaled.

    Returns:
        int64 tensor holding the prompt followed by all sampled tokens.
    """
    layer_size = 12
    embedding_size = 768
    attention_head = 12

    def _draw_next(logits):
        # Sample one token per row from the logits of the last position.
        last_step = logits[:, -1, :]
        if temperature > 0.0:
            last_step = last_step / temperature
        return top_k_top_p_sample(last_step, 1, top_k, top_p)

    # Prime the loop: full pass over the prompt, which also yields the cache.
    i = tf.constant(0, dtype=tf.int64)
    prompt_logits, kv_cache = serve(initial_inputs)
    inputs = _draw_next(prompt_logits)
    stores = tf.concat([initial_inputs, inputs], axis=1)

    def _cond(i, inputs, kv_cache, stores):
        return i < length

    def _body(i, inputs, kv_cache, stores):
        # Only the most recently sampled token is fed in; earlier positions
        # are covered by the cache.
        step_logits, step_kv = serve_cache(inputs, kv_cache)
        sampled = _draw_next(step_logits)
        grown_stores = tf.concat([stores, sampled], axis=-1)
        # Append this step's cache output to the running cache along axis -2.
        grown_cache = tf.concat([kv_cache, step_kv], axis=-2)
        return [i + 1, sampled, grown_cache, grown_stores]

    # The cache and the token store grow every iteration, so their growing
    # dimensions are declared unknown.
    invariants = [
        tf.TensorShape(None),
        tf.TensorShape([None, None]),
        tf.TensorShape([
            layer_size, None, 2,
            attention_head, None,
            embedding_size // attention_head
        ]),
        tf.TensorShape([None, None]),
    ]
    final_state = tf.while_loop(
        _cond, _body,
        loop_vars=[i, inputs, kv_cache, stores],
        shape_invariants=invariants
    )
    return final_state[-1]

In [263]:
# Sampling demo with top_k=1, top_p=1.0, temperature=1.0 and length=15.
# With top_k=1 only the most probable token stays eligible at each step.
ids = cbpe.encode('今天天气不错')

ret = sample(
    tf.constant([ids], dtype=tf.int64),
    tf.constant(15, dtype=tf.int64),
    tf.constant(1, dtype=tf.int32),
    tf.constant(1.0, dtype=tf.float32),
    tf.constant(1.0, dtype=tf.float32)
)
print(ret)
print(cbpe.decode(ret.numpy().tolist()[0]))

tf.Tensor(
[[837 259 497 788   8   9  16  29  91  14   8  12   8  10  16  18   8 616
   89 219]], shape=(1, 20), dtype=int64)
今天天气 不错 , 我 就 去 了 。   我 在 院子 里


In [270]:
# Sampling demo with top_k=10, top_p=0.95, temperature=1.0 and length=15.
# The output is random: no seed is set in this notebook.
ids = cbpe.encode('今天天气不错')

ret = sample(
    tf.constant([ids], dtype=tf.int64),
    tf.constant(15, dtype=tf.int64),
    tf.constant(10, dtype=tf.int32),
    tf.constant(0.95, dtype=tf.float32),
    tf.constant(1.0, dtype=tf.float32)
)
print(ret)
print(cbpe.decode(ret.numpy().tolist()[0]))

tf.Tensor(
[[ 837  259  497  788    8    9   51   35   25  182   11 4588   97  124
    46   63    8    9  148   16]], shape=(1, 20), dtype=int64)
今天天气 不错 , 但是 这次 的 天气 并 不是 很 好 , 所以 我


In [271]:
# Export the model as a SavedModel (optimizer state excluded) whose
# `serving_default` signature is the `sample` function traced for the specs
# below.
# NOTE(review): length / top_k / top_p / temperature are declared here as
# rank-1 tensors (shape=[None,]), while the demo cells above call `sample`
# with scalar constants — confirm the exported signature behaves as intended
# for the shapes clients will actually send.
gpt.save('./cpm-distill-tf2', include_optimizer=False, signatures={
    'serving_default': sample.get_concrete_function(
        tf.TensorSpec(shape=[None, None], dtype=tf.int64, name="inp"),
        tf.TensorSpec(shape=[None,], dtype=tf.int64, name="length"),
        tf.TensorSpec(shape=[None,], dtype=tf.int32, name="top_k"),
        tf.TensorSpec(shape=[None,], dtype=tf.float32, name="top_p"),
        tf.TensorSpec(shape=[None,], dtype=tf.float32, name="temperature")
    )
})



INFO:tensorflow:Assets written to: ./cpm-distill-tf2/assets


INFO:tensorflow:Assets written to: ./cpm-distill-tf2/assets


In [272]:
!du -sh ./cpm-distill-tf2

424M	./cpm-distill-tf2
