In [None]:
# !wget https://code.aliyun.com/qhduan/zh-bert/raw/0fb1d96ec2133fe25e66bee12fe387cbe1e52938/vocab.txt
# !pip install tokenizers

In [None]:
import os
import math
import functools

# Expose no CUDA devices to this process. The variable is assigned before
# `import tensorflow` below so it is already in place when TensorFlow loads.
# NOTE(review): this presumably forces the whole notebook, training included,
# onto the CPU — confirm that is intended and not a debugging leftover.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import tensorflow as tf
import tensorflow_addons as tfa
from tokenizers import BertWordPieceTokenizer

from model import GPT

In [None]:
def get_learning_rate(learning_rate=6e-4,
                      warmup_steps=20_0000,
                      decay_steps=200_0000,
                      alpha=0.0):
    """Build a linear-warmup / cosine-decay learning-rate function.

    Args:
        learning_rate: Peak learning rate, reached at the end of warmup.
        warmup_steps: Number of steps over which the rate grows linearly
            from 0 up to ``learning_rate`` (default 200,000). A value of 0
            disables warmup.
        decay_steps: Step at which the cosine decay is complete
            (default 2,000,000). From this step on the rate stays at its
            floor instead of following the cosine back upwards.
        alpha: Currently unused; kept so existing call sites stay valid.

    Returns:
        A function ``decayed_learning_rate(step=1)`` that maps a step number
        to a learning rate. After warmup the rate never drops below 10% of
        ``learning_rate``.

    NOTE(review): the returned object is a plain Python callable that needs
    the step passed in. Keras optimizers appear to invoke plain callables
    with no arguments, in which case ``step`` would stay at its default of 1
    — confirm how the current step reaches this function during training.
    """
    # Length of the decay phase; computed once since it does not depend on step.
    decay_span = decay_steps - warmup_steps

    def decayed_learning_rate(step=1):
        # The `warmup_steps > 0` guard avoids dividing by zero when warmup
        # is disabled.
        if warmup_steps > 0 and step <= warmup_steps:
            mult = step / float(warmup_steps)
        else:
            if decay_span > 0:
                progress = (step - warmup_steps) / decay_span
                # Clamp to [0, 1]: beyond 1 the cosine would rise again and
                # push the learning rate back towards its peak.
                progress = min(1.0, max(0.0, progress))
            else:
                # Degenerate schedule with no decay phase: go to the floor.
                progress = 1.0
            mult = 0.5 * (1 + math.cos(math.pi * progress))
            # Floor the multiplier at 10% of the peak rate.
            mult = max(0.1, mult)
        return learning_rate * mult
    return decayed_learning_rate


def data_generator(path, batch_size=4):
    """Yield ``(inputs, targets)`` batches for next-token prediction.

    Every non-blank line of the file at ``path`` is one sample made of
    tab-separated tokens. Tokens are converted to ids with the module-level
    ``tokenizer``; tokens missing from the vocabulary are mapped to the id of
    ``'[UNK]'``. Samples are grouped ``batch_size`` at a time, converted to a
    dense tensor with ``RaggedTensor.to_tensor()`` and split into inputs (all
    positions but the last) and targets (all positions but the first).

    Args:
        path: Path of the UTF-8 text file to read.
        batch_size: Number of samples per yielded batch. The final batch may
            hold fewer samples.

    Yields:
        Tuples ``(batch[:, :-1], batch[:, 1:])`` of dense tensors.
    """
    # `token_to_id` returns None for unknown tokens, which cannot be stored
    # in a tensor, so fall back to the vocabulary's unknown-token id.
    unk_id = tokenizer.token_to_id('[UNK]')

    def to_id(token):
        token_id = tokenizer.token_to_id(token)
        return unk_id if token_id is None else token_id

    def to_inputs_and_targets(samples):
        dense = tf.ragged.constant(samples).to_tensor()
        # Targets are the inputs shifted one position to the left.
        return dense[:, :-1], dense[:, 1:]

    batch = []
    # Explicit encoding so decoding does not depend on the machine's locale.
    with open(path, encoding='utf-8') as fp:
        for line in fp:
            line = line.strip()
            if not line:
                continue
            batch.append([to_id(token) for token in line.split('\t')])
            if len(batch) >= batch_size:
                yield to_inputs_and_targets(batch)
                batch = []
    # Emit whatever is left when the sample count is not a multiple of
    # batch_size.
    if batch:
        yield to_inputs_and_targets(batch)

In [None]:
# WordPiece tokenizer backed by the local `vocab.txt`.
tokenizer = BertWordPieceTokenizer('vocab.txt')

# Model hyper-parameters; the vocabulary size is taken from the tokenizer.
gpt_hparams = dict(
    vocab_size=tokenizer.get_vocab_size(),
    layer_size=12,
    block_size=1024,
    embedding_size=768,
    num_attention_heads=12,
    embedding_dropout=0.1,
    attention_dropout=0.1,
    residual_dropout=0.1,
)
gpt = GPT(**gpt_hparams)

# AdamW with decoupled weight decay and gradient-norm clipping.
# NOTE(review): `get_learning_rate()` returns a plain callable taking `step`;
# confirm the optimizer actually supplies the step, otherwise the default
# `step=1` would be used on every call.
adamw = tfa.optimizers.AdamW(
    weight_decay=0.1,
    learning_rate=get_learning_rate(),
    beta_1=0.9,
    beta_2=0.95,
    epsilon=1e-8,
    name='AdamW',
    clipnorm=1.0,
)
# No loss is passed here — presumably the GPT model computes its own; verify
# against model.py.
gpt.compile(optimizer=adamw)

# Declare the input signature: int32 token ids of variable length.
token_ids_input = tf.keras.layers.Input((None,), dtype=tf.int32)
gpt._set_inputs(token_ids_input)

In [None]:
# Zero-argument factory that `from_generator` calls to obtain the batch
# iterator; the corpus location is an absolute local path.
make_batches = functools.partial(
    data_generator,
    path='/home/qhduan/DATA10T/seagate/DATASETS/bd-chat/chat.txt')

# Inputs and targets are both int32 matrices whose two dimensions are left
# unspecified.
unknown_2d = tf.TensorShape([None, None])

dataset = tf.data.Dataset.from_generator(
    make_batches,
    output_types=(tf.int32, tf.int32),
    output_shapes=(unknown_2d, unknown_2d),
)

In [None]:
# Train for 10 epochs over `dataset`. No batch size is given here because
# every element produced by `data_generator` is already one
# (inputs, targets) batch.
gpt.fit(dataset, epochs=10)

In [None]:
# Write the trained model to /tmp/gpt3.
# NOTE(review): /tmp is usually not persistent — copy the result elsewhere
# if it needs to be kept.
gpt.save('/tmp/gpt3')

In [None]:
# gpt_load = tf.keras.models.load_model('/tmp/gpt3')