In [1]:
import re

import torch
import numpy as np
import tensorflow as tf

from tf2gpt.model import GPT

In [2]:
# The checkpoint is split into two model-parallel shards (rank 0 and rank 1);
# both are loaded onto the CPU.  torch.load unpickles, so only load trusted files.
shard_paths = (
    './model-v1/80000/mp_rank_00_model_states.pt',
    './model-v1/80000/mp_rank_01_model_states.pt',
)
m0, m1 = [torch.load(path, map_location='cpu') for path in shard_paths]

In [3]:
def find_weight(model, name):
    """Look up one parameter in a checkpoint shard.

    Args:
        model: checkpoint dict as returned by ``torch.load``; parameters are read
            from ``model['module']``, keyed by their dotted name.
        name: parameter name, e.g. ``'word_embeddings.weight'``.

    Returns:
        ``(tensor, shape)`` where ``shape`` is ``list(tensor.shape)``.

    Raises:
        KeyError: if ``name`` is not present, so a typo in a parameter name fails
            loudly here instead of as a ``None`` unpacking error at the call site.
    """
    params = model['module']
    if name not in params:
        raise KeyError(f'{name!r} not found in checkpoint shard')
    weight = params[name]
    return weight, list(weight.shape)

In [4]:
# List rank-0 parameter shapes: everything outside the transformer stack, plus
# layer 0 as a representative of all layers.
for param_name, tensor in m0['module'].items():
    if '.layers.' not in param_name or '.layers.0.' in param_name:
        print(param_name, tensor.shape)

# for n, w in m1['module'].items():
#     print(n, w.shape)

word_embeddings.weight torch.Size([15000, 2560])
position_embeddings.weight torch.Size([1024, 2560])
transformer.layers.0.input_layernorm.weight torch.Size([2560])
transformer.layers.0.input_layernorm.bias torch.Size([2560])
transformer.layers.0.attention.query_key_value.weight torch.Size([3840, 2560])
transformer.layers.0.attention.query_key_value.bias torch.Size([3840])
transformer.layers.0.attention.dense.weight torch.Size([2560, 1280])
transformer.layers.0.attention.dense.bias torch.Size([2560])
transformer.layers.0.post_attention_layernorm.weight torch.Size([2560])
transformer.layers.0.post_attention_layernorm.bias torch.Size([2560])
transformer.layers.0.mlp.dense_h_to_4h.weight torch.Size([5120, 2560])
transformer.layers.0.mlp.dense_h_to_4h.bias torch.Size([5120])
transformer.layers.0.mlp.dense_4h_to_h.weight torch.Size([2560, 5120])
transformer.layers.0.mlp.dense_4h_to_h.bias torch.Size([2560])
transformer.final_layernorm.weight torch.Size([2560])
transformer.final_layernorm.bia

In [5]:
# tf.keras.backend.set_floatx('float16')

In [6]:
# Hyper-parameters must agree with the checkpoint shapes printed above
# (2 x 15000 vocab rows, hidden size 2560, 1024 positions).  Dropouts are zero
# because the converted model is only used for inference here.
gpt_config = dict(
    vocab_size=30_000,
    block_size=1024,
    layer_size=32,
    num_attention_heads=32,
    embedding_size=2560,
    embedding_dropout=0.0,
    attention_dropout=0.0,
    residual_dropout=0.0,
)
gpt = GPT(**gpt_config)

In [7]:
# One forward pass on a single token.  Presumably GPT is a Keras model that builds
# its variables lazily, so this also populates gpt.weights for the next cells.
smoke_logits = gpt(tf.constant([[1]]))
print(smoke_logits.shape)

(1, 1, 30000)


In [8]:
# for x in gpt.weights:
#     assert x.dtype == tf.float16

In [9]:
# List TF variable shapes: everything outside the layer stack, plus layer 00.
for variable in gpt.weights:
    if 'gpt/layer' not in variable.name or 'gpt/layer00' in variable.name:
        print(variable.name, variable.shape)

gpt/embedding/embeddings:0 (30000, 2560)
position_embeddings:0 (1024, 2560)
gpt/layer00/attention/query_layer/kernel:0 (2560, 2560)
gpt/layer00/attention/query_layer/bias:0 (2560,)
gpt/layer00/attention/key_layer/kernel:0 (2560, 2560)
gpt/layer00/attention/key_layer/bias:0 (2560,)
gpt/layer00/attention/value_layer/kernel:0 (2560, 2560)
gpt/layer00/attention/value_layer/bias:0 (2560,)
gpt/layer00/attention/context_projection_layer/kernel:0 (2560, 2560)
gpt/layer00/attention/context_projection_layer/bias:0 (2560,)
gpt/layer00/LayerNorm_mlp_ln0/gamma:0 (2560,)
gpt/layer00/LayerNorm_mlp_ln0/beta:0 (2560,)
gpt/layer00/LayerNorm_mlp_ln1/gamma:0 (2560,)
gpt/layer00/LayerNorm_mlp_ln1/beta:0 (2560,)
gpt/layer00/intermediate/kernel:0 (2560, 10240)
gpt/layer00/intermediate/bias:0 (10240,)
gpt/layer00/output/kernel:0 (10240, 2560)
gpt/layer00/output/bias:0 (2560,)
gpt/LayerNorm_final_norm/gamma:0 (2560,)
gpt/LayerNorm_final_norm/beta:0 (2560,)


In [10]:
# Build one numpy array per TF variable of `gpt`, in `gpt.weights` order, from the
# two model-parallel shards (m0 = rank 0, m1 = rank 1).  Each entry is
# (tf_name, tf_shape, torch_name, array), ready to be fed to `gpt.set_weights`.
#
# Layout relied on below (checked by the asserts and the shapes printed earlier):
#   * torch Linear weights are (out, in); Keras kernels are (in, out) -> transpose.
#   * word embeddings, query_key_value and dense_h_to_4h are split across the two
#     shards along torch axis 0; attention.dense and dense_4h_to_h along axis 1.
#   * every other parameter has its full shape on rank 0, so only rank 0 is read.
new_weights = []

# Rows of q / k / v inside ONE shard's fused query_key_value tensor (3840 = 3 * 1280).
# Assumes each shard stores [q; k; v] in that order along axis 0.
QKV_ROWS = {
    'query_layer': slice(0, 1280),
    'key_layer': slice(1280, 1280 * 2),
    'value_layer': slice(1280 * 2, 1280 * 3),
}

# TF layer-norm variable name fragment -> torch parameter suffix within one layer.
LAYER_NORM_PARAMS = {
    'LayerNorm_mlp_ln0/gamma': 'input_layernorm.weight',
    'LayerNorm_mlp_ln0/beta': 'input_layernorm.bias',
    'LayerNorm_mlp_ln1/gamma': 'post_attention_layernorm.weight',
    'LayerNorm_mlp_ln1/beta': 'post_attention_layernorm.bias',
}


def _rank0(pname):
    """Numpy copy of a parameter read from rank 0 only (not split across shards)."""
    return find_weight(m0, pname)[0].numpy()


def _both_ranks(pname):
    """Numpy copies (rank0, rank1) of a parameter that is split across the shards."""
    return find_weight(m0, pname)[0].numpy(), find_weight(m1, pname)[0].numpy()


def _first_match(name, table):
    """Value of the first key of `table` that occurs in `name`, or None."""
    return next((value for key, value in table.items() if key in name), None)


for x in gpt.weights:
    xs = list(x.shape)

    if 'gpt/embedding/embeddings:' in x.name:
        pname = 'word_embeddings.weight'
        a0, a1 = _both_ranks(pname)
        assert a0.shape == a1.shape == (15_000, 2560)
        w = np.concatenate([a0, a1])
        assert w.shape == (30_000, 2560)

    elif 'position_embeddings' in x.name:
        pname = 'position_embeddings.weight'
        w = _rank0(pname)
        assert list(w.shape) == xs

    elif 'gpt/layer' in x.name:
        n_layer = int(x.name[len('gpt/layer'):][:2])
        prefix = f'transformer.layers.{n_layer}'
        qkv_rows = _first_match(x.name, QKV_ROWS)
        ln_suffix = _first_match(x.name, LAYER_NORM_PARAMS)

        if qkv_rows is not None and '/kernel' in x.name:
            pname = f'{prefix}.attention.query_key_value.weight'
            a0, a1 = _both_ranks(pname)
            assert a0.shape == a1.shape == (3840, 2560)
            # Slice q/k/v out of each shard, stack the shards, then (out, in) -> (in, out).
            w = np.concatenate([a0[qkv_rows], a1[qkv_rows]]).T
            assert w.shape == (2560, 2560)

        elif qkv_rows is not None and '/bias' in x.name:
            pname = f'{prefix}.attention.query_key_value.bias'
            a0, a1 = _both_ranks(pname)
            assert a0.shape == a1.shape == (3840,)
            w = np.concatenate([a0[qkv_rows], a1[qkv_rows]])
            assert w.shape == (2560,)

        elif 'attention/context_projection_layer/kernel' in x.name:
            pname = f'{prefix}.attention.dense.weight'
            a0, a1 = _both_ranks(pname)
            w = np.concatenate([a0, a1], axis=1).T
            assert w.shape == (2560, 2560)

        elif 'attention/context_projection_layer/bias' in x.name:
            pname = f'{prefix}.attention.dense.bias'
            w = _rank0(pname)
            assert w.shape == (2560,)

        elif ln_suffix is not None:
            pname = f'{prefix}.{ln_suffix}'
            w = _rank0(pname)
            assert list(w.shape) == xs

        elif 'intermediate/kernel' in x.name:
            pname = f'{prefix}.mlp.dense_h_to_4h.weight'
            a0, a1 = _both_ranks(pname)
            w = np.concatenate([a0, a1], axis=0).T
            assert w.shape == (2560, 10240)

        elif 'intermediate/bias' in x.name:
            pname = f'{prefix}.mlp.dense_h_to_4h.bias'
            a0, a1 = _both_ranks(pname)
            w = np.concatenate([a0, a1])
            assert w.shape == (10240,)

        elif '/output/kernel' in x.name:
            pname = f'{prefix}.mlp.dense_4h_to_h.weight'
            a0, a1 = _both_ranks(pname)
            w = np.concatenate([a0, a1], axis=1).T
            assert w.shape == (10240, 2560)

        elif '/output/bias' in x.name:
            pname = f'{prefix}.mlp.dense_4h_to_h.bias'
            w = _rank0(pname)
            assert w.shape == (2560,)

        else:
            # Fail loudly instead of silently stopping with a partial conversion.
            raise ValueError(f'No checkpoint mapping for TF variable {x.name} {xs}')

    elif 'gpt/LayerNorm_final_norm/gamma' in x.name:
        pname = 'transformer.final_layernorm.weight'
        w = _rank0(pname)
        assert list(w.shape) == xs

    elif 'gpt/LayerNorm_final_norm/beta' in x.name:
        pname = 'transformer.final_layernorm.bias'
        w = _rank0(pname)
        assert list(w.shape) == xs

    else:
        raise ValueError(f'No checkpoint mapping for TF variable {x.name} {xs}')

    new_weights.append((x.name, xs, pname, w))

In [11]:
# Every TF variable must have received exactly one converted array.
assert len(gpt.weights) == len(new_weights)

In [12]:
# Each converted array must have exactly the shape of its target TF variable.
for _tf_name, tf_shape, _torch_name, array in new_weights:
    assert array.shape == tuple(tf_shape)

In [13]:
# new_weights was built by iterating gpt.weights in order, which is the order
# set_weights expects.
converted_arrays = [array for *_, array in new_weights]
gpt.set_weights(converted_arrays)

In [14]:
from gpt2_tokenizer import GPT2Tokenizer

In [15]:
# CPM-Generate's tokenizer files: BPE vocab + merges plus an extra model file.
vocab_path = 'CPM-Generate/bpe_3w_new/vocab.json'
merges_path = 'CPM-Generate/bpe_3w_new/merges.txt'
tokenizer_model_path = 'CPM-Generate/bpe_3w_new/chinese_vocab.model'
cbpe = GPT2Tokenizer(vocab_path, merges_path, model_file=tokenizer_model_path)

In [16]:
ids = cbpe.encode('今天天气不错')

# Greedy decoding without a cache: re-run the whole sequence every step and
# append the arg-max token of the last position.
for step in range(10):
    logits = gpt(tf.constant([ids]))
    next_id = np.argmax(logits[0, -1])
    ids.append(next_id)
    print(step, cbpe.decode(ids))

Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.623 seconds.
Prefix dict has been built successfully.


0 今天天气不错
1 今天天气不错,
2 今天天气不错,我
3 今天天气不错,我想
4 今天天气不错,我想去
5 今天天气不错,我想去看看
6 今天天气不错,我想去看看
7 今天天气不错,我想去看看。
8 今天天气不错,我想去看看。
9 今天天气不错,我想去看看。”


In [17]:
# Token ids after greedy decoding (prompt followed by generated ids).
print(str(ids))

[837, 259, 497, 788, 8, 9, 16, 84, 91, 881, 8, 12, 8, 34]


In [31]:
def batch_gather(a, b):
    """Row-wise gather: out[i, j] = a[i, b[i, j]] (tf.gather with batch_dims=1)."""
    return tf.gather(a, b, batch_dims=1)


def top_p_sample(logits, num_samples=1, p=0.95):
    """Nucleus (top-p) sampling.

    Keeps the smallest set of most-likely tokens whose total probability reaches
    `p` (always at least the single most likely token), then samples from it.

    Args:
        logits: [batch_size, vocab_size] unnormalised scores.
        num_samples: number of token ids drawn per row.
        p: probability-mass threshold; a scalar or a tensor broadcastable against
            [batch_size, vocab_size].

    Returns:
        int64 tensor [batch_size, num_samples] of ids in the original
        (unsorted) vocabulary.
    """
    # Dynamic size so the function also traces when the vocab dim is unknown.
    vocab_size = tf.shape(logits)[-1]
    probs = tf.nn.softmax(logits, axis=-1)
    # [batch_size, vocab_size]: token ids ordered by decreasing probability.
    indices = tf.argsort(probs, direction='DESCENDING')
    # Exclusive cumsum = mass of all strictly more likely tokens.  A token is kept
    # while that mass is still < p, so the token that crosses p is included.
    mass_before = tf.math.cumsum(batch_gather(probs, indices), axis=-1, exclusive=True)

    # Never exclude the most likely token, whatever p is.
    exclude_mask = tf.logical_not(
        tf.logical_or(mass_before < p, tf.range(vocab_size)[None] < 1))

    # Sample in the sorted space, then map back to vocabulary ids.
    sorted_logits = batch_gather(logits, indices)
    # tf.where leaves kept logits untouched and works for any float dtype of logits.
    logits_to_use = tf.where(exclude_mask, sorted_logits - 1e10, sorted_logits)
    sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
    sample = batch_gather(indices, sample_perm)

    return tf.cast(sample, tf.int64)

In [19]:
@tf.function
def serve(inputs):
    """Prompt pass: run `gpt` with no previous cache; returns (logits, kv_cache)."""
    return gpt(inputs, use_cache=True, kv_cache=None)


@tf.function
def serve_cache(inputs, kv_cache):
    """Incremental pass: run `gpt` on new tokens given an existing kv_cache."""
    return gpt(inputs, use_cache=True, kv_cache=kv_cache)

In [20]:
token_spec = tf.TensorSpec(shape=[None, None], dtype=tf.int64, name="inp")
serve_concrete = serve.get_concrete_function(token_spec)

layer_size = 32
attention_head = 32
embedding_size = 2560

# Cache layout as returned by `serve`, e.g. (32, 1, 2, 32, 1, 80) for a one-token
# prompt.  The two None axes vary: presumably axis 1 = batch and axis 4 = sequence
# (`sample` appends new positions along axis -2) — TODO confirm against GPT.
kv_cache_spec = tf.TensorSpec(
    shape=[layer_size, None, 2, attention_head, None, embedding_size // attention_head],
    dtype=tf.float32,
    name="kv_cache",
)
serve_cache_concrete = serve_cache.get_concrete_function(token_spec, kv_cache_spec)

In [21]:
# Prompt pass on a single token id.
one_token = tf.constant([[1]], tf.int64)
r = serve_concrete(one_token)

In [22]:
# Shapes of (logits, kv_cache) from the prompt pass.
print(*(r[k].shape for k in (0, 1)))

(1, 1, 30000) (32, 1, 2, 32, 1, 80)


In [23]:
# Incremental pass: feed token 1 again, reusing the cache produced by `r`.
cached_kv = r[1]
r2 = serve_cache_concrete(tf.constant([[1]], tf.int64), cached_kv)

In [24]:
# Shapes of (logits, kv_cache) from the incremental pass.
print(*(r2[k].shape for k in (0, 1)))

(1, 1, 30000) (32, 1, 2, 32, 1, 80)


In [33]:
@tf.function
def sample(initial_inputs, length, top_p, temperature):
    """Extend `initial_inputs` autoregressively with nucleus sampling.

    One `serve` pass over the prompt samples a first token.  Then `length`
    iterations of `serve_cache` each feed only the newest token plus the
    accumulated kv cache.

    Args:
        initial_inputs: int64 [batch, prompt_len] token ids.
        length: number of loop iterations.  NOTE(review): the first token is
            sampled before the loop, so `length + 1` tokens are generated in total
            (length=15 on a 4-token prompt yields 20 ids in the run below);
            confirm whether that is intended.
        top_p: nucleus threshold forwarded to `top_p_sample`.
        temperature: logits are divided by this before sampling.

    Returns:
        int64 [batch, prompt_len + length + 1]: prompt followed by sampled ids.
    """
    n_layers = 32
    hidden = 2560
    n_heads = 32

    step0 = tf.constant(0, dtype=tf.int64)
    prompt_logits, cache0 = serve(initial_inputs)
    first_token = top_p_sample(prompt_logits[:, -1, :] / temperature, 1, top_p)
    tokens0 = tf.concat([initial_inputs, first_token], axis=1)

    def keep_going(step, last_token, cache, tokens):
        return step < length

    def decode_step(step, last_token, cache, tokens):
        step_logits, step_cache = serve_cache(last_token, cache)
        next_token = top_p_sample(step_logits[:, -1, :] / temperature, 1, top_p)
        grown_tokens = tf.concat([tokens, next_token], axis=-1)
        # serve_cache returns only the new position's keys/values; append them
        # along axis -2 so the cache covers the whole sequence so far.
        grown_cache = tf.concat([cache, step_cache], axis=-2)
        return [step + 1, next_token, grown_cache, grown_tokens]

    # Batch and sequence axes change across iterations, so they are left as None.
    cache_invariant = tf.TensorShape([n_layers, None, 2, n_heads, None, hidden // n_heads])
    final_vars = tf.while_loop(
        keep_going, decode_step,
        loop_vars=[step0, first_token, cache0, tokens0],
        shape_invariants=[
            tf.TensorShape(None),
            tf.TensorShape([None, None]),
            cache_invariant,
            tf.TensorShape([None, None]),
        ]
    )
    return final_vars[-1]

In [34]:
ids = cbpe.encode('今天天气不错')

prompt = tf.constant([ids], dtype=tf.int64)
steps = tf.constant(15, dtype=tf.int64)
nucleus_p = tf.constant(0.95, dtype=tf.float32)
temp = tf.constant(0.9, dtype=tf.float32)

ret = sample(prompt, steps, nucleus_p, temp)
print(ret)
# Decode the first (and only) row of the batch.
first_row = ret.numpy().tolist()[0]
print(cbpe.decode(first_row))

tf.Tensor(
[[  837   259   497   788     8     9    86    26    29    24 14086    14
      8    12     8    10   366  1639     8     9]], shape=(1, 20), dtype=int64)
今天天气不错,她也就不计较了。 至此,


In [35]:
# NOTE(review): length / top_p / temperature are exported as 1-D [None] tensors.
# Presumably callers pass length-1 vectors; a longer `length` would make the
# while-loop condition in `sample` non-scalar. TODO confirm the intended contract.
serving_fn = sample.get_concrete_function(
    tf.TensorSpec(shape=[None, None], dtype=tf.int64, name="inp"),
    tf.TensorSpec(shape=[None,], dtype=tf.int64, name="length"),
    tf.TensorSpec(shape=[None,], dtype=tf.float32, name="top_p"),
    tf.TensorSpec(shape=[None,], dtype=tf.float32, name="temperature")
)
gpt.save(
    './cpm-lm-tf2_v2',
    include_optimizer=False,
    signatures={'serving_default': serving_fn},
)

Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: ./cpm-lm-tf2_v2/assets


In [36]:
!du -sh ./cpm-lm-tf2_v2

9.8G	./cpm-lm-tf2_v2
