In [1]:
import re

import torch
import numpy as np
import tensorflow as tf

from tf2gpt.model import GPT

In [2]:
# Load the two `mp_rank_*` checkpoint files onto the CPU (no GPU required).
# NOTE(review): torch.load unpickles the file contents, so only point this at
# checkpoints obtained from a trusted source.
m0 = torch.load('./model-v1/80000/mp_rank_00_model_states.pt', map_location='cpu')
m1 = torch.load('./model-v1/80000/mp_rank_01_model_states.pt', map_location='cpu')

In [3]:
def find_weight(model, name):
    """Look up one parameter in a loaded checkpoint.

    Args:
        model: Checkpoint dict whose ``'module'`` entry maps parameter names
            to tensors (any object exposing ``.shape``).
        name: Exact parameter name to fetch.

    Returns:
        A ``(tensor, shape)`` pair, where ``shape`` is ``tensor.shape``
        converted to a plain ``list``.

    Raises:
        KeyError: If ``name`` is not present in ``model['module']``.
    """
    params = model['module']
    if name not in params:
        # Fail here, naming the missing parameter, instead of returning None:
        # callers unpack the result and would otherwise only see an opaque
        # "cannot unpack non-iterable NoneType" error.
        raise KeyError(f'weight {name!r} not found in checkpoint')
    w = params[name]
    return w, list(w.shape)

In [4]:
# Show the name and shape of every rank-0 parameter, but for the repeated
# transformer layers print layer 0 only to keep the listing short.
for n, w in m0['module'].items():
    if '.layers.' in n and '.layers.0.' not in n:
        continue
    print(n, w.shape)

# for n, w in m1['module'].items():
#     print(n, w.shape)

word_embeddings.weight torch.Size([15000, 2560])
position_embeddings.weight torch.Size([1024, 2560])
transformer.layers.0.input_layernorm.weight torch.Size([2560])
transformer.layers.0.input_layernorm.bias torch.Size([2560])
transformer.layers.0.attention.query_key_value.weight torch.Size([3840, 2560])
transformer.layers.0.attention.query_key_value.bias torch.Size([3840])
transformer.layers.0.attention.dense.weight torch.Size([2560, 1280])
transformer.layers.0.attention.dense.bias torch.Size([2560])
transformer.layers.0.post_attention_layernorm.weight torch.Size([2560])
transformer.layers.0.post_attention_layernorm.bias torch.Size([2560])
transformer.layers.0.mlp.dense_h_to_4h.weight torch.Size([5120, 2560])
transformer.layers.0.mlp.dense_h_to_4h.bias torch.Size([5120])
transformer.layers.0.mlp.dense_4h_to_h.weight torch.Size([2560, 5120])
transformer.layers.0.mlp.dense_4h_to_h.bias torch.Size([2560])
transformer.final_layernorm.weight torch.Size([2560])
transformer.final_layernorm.bias torch.Size([2560])

In [5]:
# tf.keras.backend.set_floatx('float16')

In [6]:
# Instantiate the TF model that will receive the converted weights.
# vocab_size, block_size and embedding_size have to agree with the checkpoint:
# the conversion cell asserts a merged (30000, 2560) word-embedding matrix and
# compares the position-embedding variable against the (1024, 2560) tensor.
# All dropout rates are set to 0.0.
# NOTE(review): layer_size is presumably the number of transformer layers —
# confirm against tf2gpt.model.GPT.
gpt = GPT(
    vocab_size=30_000,
    layer_size=32,
    block_size=1024,
    embedding_dropout=0.0,
    embedding_size=2560,
    num_attention_heads=32,
    attention_dropout=0.0,
    residual_dropout=0.0)

In [7]:
# Describe the model input: an int32 tensor with two unspecified dimensions,
# named "input_strs".
strs = tf.TensorSpec(shape=[None, None],
                     dtype=tf.int32,
                     name="input_strs")

# NOTE(review): `_set_inputs` is a private Keras method; presumably it is
# called here so that the later `gpt.save` knows the input signature — confirm.
gpt._set_inputs(strs)

In [8]:
# Run the model once on a single token id and print the output shape.
# NOTE(review): presumably this first call is also what creates the variables
# later listed through `gpt.weights` — confirm.
print(gpt(tf.constant([[1]])).shape)

(1, 1, 30000)


In [9]:
# for x in gpt.weights:
#     assert x.dtype == tf.float16

In [10]:
# Show the name and shape of every TF variable, but for the repeated
# transformer layers print layer 00 only to keep the listing short.
for x in gpt.weights:
    if 'gpt/layer' in x.name and 'gpt/layer00' not in x.name:
        continue
    print(x.name, x.shape)

gpt/embedding/embeddings:0 (30000, 2560)
position_embeddings:0 (1024, 2560)
gpt/layer00/attention/query_layer/kernel:0 (2560, 2560)
gpt/layer00/attention/query_layer/bias:0 (2560,)
gpt/layer00/attention/key_layer/kernel:0 (2560, 2560)
gpt/layer00/attention/key_layer/bias:0 (2560,)
gpt/layer00/attention/value_layer/kernel:0 (2560, 2560)
gpt/layer00/attention/value_layer/bias:0 (2560,)
gpt/layer00/attention/context_projection_layer/kernel:0 (2560, 2560)
gpt/layer00/attention/context_projection_layer/bias:0 (2560,)
gpt/layer00/LayerNorm_mlp_ln0/gamma:0 (2560,)
gpt/layer00/LayerNorm_mlp_ln0/beta:0 (2560,)
gpt/layer00/LayerNorm_mlp_ln1/gamma:0 (2560,)
gpt/layer00/LayerNorm_mlp_ln1/beta:0 (2560,)
gpt/layer00/intermediate/kernel:0 (2560, 10240)
gpt/layer00/intermediate/bias:0 (10240,)
gpt/layer00/output/kernel:0 (10240, 2560)
gpt/layer00/output/bias:0 (2560,)
gpt/LayerNorm_final_norm/gamma:0 (2560,)
gpt/LayerNorm_final_norm/beta:0 (2560,)


In [11]:
# Translate the two PyTorch checkpoint shards (m0 = rank 0, m1 = rank 1) into
# NumPy arrays laid out for the TF model.  One entry is appended per TF
# variable, in `gpt.weights` order, as (tf_name, tf_shape, torch_name, array).
#
# Matrices are stored in the checkpoint transposed relative to the TF kernels
# (compare the shapes printed above), hence the np.transpose calls.
# Any TF variable without a known mapping raises immediately, so the list can
# never be left silently incomplete.
new_weights = []

for x in gpt.weights:
    xs = list(x.shape)

    if 'gpt/embedding/embeddings:' in x.name:
        # Each rank holds 15000 embedding rows; stack rank 0 above rank 1.
        pname = 'word_embeddings.weight'
        w0, ws0 = find_weight(m0, pname)
        w1, ws1 = find_weight(m1, pname)
        assert ws0 == [15_000, 2560]
        assert ws1 == [15_000, 2560]
        w = np.concatenate([w0.numpy(), w1.numpy()])
        assert w.shape == (30_000, 2560)
        new_weights.append((x.name, xs, pname, w))

    elif 'position_embeddings' in x.name:
        # Taken from rank 0 only; its shape already matches the TF variable.
        pname = 'position_embeddings.weight'
        w0, ws0 = find_weight(m0, pname)
        assert xs == ws0
        w = w0.numpy()
        new_weights.append((x.name, xs, pname, w))

    elif 'gpt/layer' in x.name:
        # Variable names look like 'gpt/layerNN/...'; NN is the layer index.
        n_layer = int(x.name[len('gpt/layer'):][:2])

        if 'query_layer' in x.name or 'key_layer' in x.name or 'value_layer' in x.name:
            # On each rank the fused query_key_value tensor has 3840 rows:
            # rows [0, 1280) are used as query, [1280, 2560) as key and
            # [2560, 3840) as value.  The matching block is taken from both
            # ranks and concatenated, rank 0 first.
            if 'query_layer' in x.name:
                rows = slice(0, 1280)
            elif 'key_layer' in x.name:
                rows = slice(1280, 1280 * 2)
            else:  # value_layer, guaranteed by the enclosing condition
                rows = slice(1280 * 2, None)

            if '/kernel' in x.name:
                pname = f'transformer.layers.{n_layer}.attention.query_key_value.weight'
                w0, ws0 = find_weight(m0, pname)
                w1, ws1 = find_weight(m1, pname)
                assert ws0 == [3840, 2560]
                assert ws1 == [3840, 2560]
                w = np.concatenate([w0.numpy()[rows, :], w1.numpy()[rows, :]])
                w = np.transpose(w)
                assert w.shape == (2560, 2560)
                new_weights.append((x.name, xs, pname, w))

            elif '/bias' in x.name:
                pname = f'transformer.layers.{n_layer}.attention.query_key_value.bias'
                w0, ws0 = find_weight(m0, pname)
                w1, ws1 = find_weight(m1, pname)
                assert ws0 == [3840,]
                assert ws1 == [3840,]
                w = np.concatenate([w0.numpy()[rows], w1.numpy()[rows]])
                assert w.shape == (2560,)
                new_weights.append((x.name, xs, pname, w))

            else:
                # Previously such a variable was skipped without any message.
                raise ValueError(f'no checkpoint mapping for TF weight {x.name} with shape {xs}')

        elif 'attention/context_projection_layer/kernel' in x.name:
            # Each rank holds a (2560, 1280) slice; join along the columns.
            pname = f'transformer.layers.{n_layer}.attention.dense.weight'
            w0, ws0 = find_weight(m0, pname)
            w1, ws1 = find_weight(m1, pname)
            w = np.concatenate([w0.numpy(), w1.numpy()], axis=1)
            w = np.transpose(w)
            assert w.shape == (2560, 2560)
            new_weights.append((x.name, xs, pname, w))

        elif 'attention/context_projection_layer/bias' in x.name:
            # Taken from rank 0 only.
            # NOTE(review): presumably identical on both ranks — confirm.
            pname = f'transformer.layers.{n_layer}.attention.dense.bias'
            w0, ws0 = find_weight(m0, pname)
            w = w0.numpy()
            assert w.shape == (2560,)
            new_weights.append((x.name, xs, pname, w))

        elif 'LayerNorm_mlp_ln0/gamma' in x.name:
            pname = f'transformer.layers.{n_layer}.input_layernorm.weight'
            w0, ws0 = find_weight(m0, pname)
            w = w0.numpy()
            assert ws0 == xs
            # Store the list form `xs`, like every other branch.
            new_weights.append((x.name, xs, pname, w))

        elif 'LayerNorm_mlp_ln1/gamma' in x.name:
            pname = f'transformer.layers.{n_layer}.post_attention_layernorm.weight'
            w0, ws0 = find_weight(m0, pname)
            w = w0.numpy()
            assert ws0 == xs
            new_weights.append((x.name, xs, pname, w))

        elif 'LayerNorm_mlp_ln0/beta' in x.name:
            pname = f'transformer.layers.{n_layer}.input_layernorm.bias'
            w0, ws0 = find_weight(m0, pname)
            w = w0.numpy()
            assert ws0 == xs
            new_weights.append((x.name, xs, pname, w))

        elif 'LayerNorm_mlp_ln1/beta' in x.name:
            pname = f'transformer.layers.{n_layer}.post_attention_layernorm.bias'
            w0, ws0 = find_weight(m0, pname)
            w = w0.numpy()
            assert ws0 == xs
            new_weights.append((x.name, xs, pname, w))

        elif 'intermediate/kernel' in x.name:
            # Each rank holds a (5120, 2560) slice; join along the rows.
            pname = f'transformer.layers.{n_layer}.mlp.dense_h_to_4h.weight'
            w0, ws0 = find_weight(m0, pname)
            w1, ws1 = find_weight(m1, pname)
            w = np.concatenate([w0.numpy(), w1.numpy()], axis=0)
            w = np.transpose(w)
            assert w.shape == (2560, 10240)
            new_weights.append((x.name, xs, pname, w))

        elif 'intermediate/bias' in x.name:
            # 1-D bias: the two 5120-element halves are simply joined
            # (a transpose would be a no-op on a 1-D array).
            pname = f'transformer.layers.{n_layer}.mlp.dense_h_to_4h.bias'
            w0, ws0 = find_weight(m0, pname)
            w1, ws1 = find_weight(m1, pname)
            w = np.concatenate([w0.numpy(), w1.numpy()], axis=0)
            assert w.shape == (10240,)
            new_weights.append((x.name, xs, pname, w))

        elif '/output/kernel' in x.name:
            # Each rank holds a (2560, 5120) slice; join along the columns.
            pname = f'transformer.layers.{n_layer}.mlp.dense_4h_to_h.weight'
            w0, ws0 = find_weight(m0, pname)
            w1, ws1 = find_weight(m1, pname)
            w = np.concatenate([w0.numpy(), w1.numpy()], axis=1)
            w = np.transpose(w)
            assert w.shape == (10240, 2560)
            new_weights.append((x.name, xs, pname, w))

        elif '/output/bias' in x.name:
            # Taken from rank 0 only.
            # NOTE(review): presumably identical on both ranks — confirm.
            pname = f'transformer.layers.{n_layer}.mlp.dense_4h_to_h.bias'
            w0, ws0 = find_weight(m0, pname)
            w = w0.numpy()
            assert w.shape == (2560,)
            new_weights.append((x.name, xs, pname, w))

        else:
            # Stop with an error rather than printing and breaking out of the
            # loop, which left `new_weights` truncated.
            raise ValueError(f'no checkpoint mapping for TF weight {x.name} with shape {xs}')

    elif 'gpt/LayerNorm_final_norm/gamma' in x.name:
        pname = 'transformer.final_layernorm.weight'
        w0, ws0 = find_weight(m0, pname)
        w = w0.numpy()
        assert ws0 == xs
        new_weights.append((x.name, xs, pname, w))

    elif 'gpt/LayerNorm_final_norm/beta' in x.name:
        pname = 'transformer.final_layernorm.bias'
        w0, ws0 = find_weight(m0, pname)
        w = w0.numpy()
        assert ws0 == xs
        new_weights.append((x.name, xs, pname, w))

    else:
        raise ValueError(f'no checkpoint mapping for TF weight {x.name} with shape {xs}')

In [12]:
# The conversion must have produced one entry per TF variable.
assert len(new_weights) == len(gpt.weights)

In [13]:
# Each converted array must have the shape recorded for its TF variable.
for x in new_weights:
    tf_shape, array = x[1], x[-1]
    assert array.shape == tuple(tf_shape)

In [14]:
# Load the converted arrays into the model; `new_weights` was built in
# `gpt.weights` order, and the last element of each entry is the array.
gpt.set_weights([x[-1] for x in new_weights])

In [15]:
# Write the model, now carrying the converted weights, to ./cpm-logits-model.
gpt.save('./cpm-logits-model')

Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: ./cpm-logits-model/assets
