In [2]:
from tfchat.configs import Config
from tfchat.data import BlockDataset
from tfchat.metrics import perplexity
from tfchat.losses import PaddingLoss
from tfchat.schedules import WarmupLinearDecay
from tfchat.generations import TopKTopPGenerator
from tfchat.models import PreLNDecoder

import tensorflow.keras as keras
import numpy as np


# Define model config
# These hyperparameters are consumed later: the object is passed to the model
# constructor, and context_size also sets the dataset block size.
config = Config(
    num_layers=6,
    d_model=64,
    num_heads=1,
    d_ff=256,
    vocab_size=100,
    context_size=64,
    attention_dropout_rate=0.1,
    residual_dropout_rate=0.1,
    embedding_dropout_rate=0.1,
    epsilon=1e-06,
)

# Alternatively, use a predefined config instead of defining one yourself:
#
# from tfchat.configs import GPT2SmallConfig
# config = GPT2SmallConfig()


# Define training parameters
batch_size = 2
epochs = 10

# Prepare dataset
# Token ids are the int32 sequence 0..9 repeated end to end:
# 1000 repetitions for training data and 100 for validation data.
train_ids, valid_ids = (
    np.tile(np.arange(10, dtype=np.int32), repetitions)
    for repetitions in (1000, 100)
)

# A single builder is configured once and reused for both splits, so they
# share the same block size (the model's context size) and batch size.
dataset = BlockDataset(
    block_size=config.context_size,
    batch_size=batch_size,
)
# Only the training split is built with shuffling enabled.
train_dataset = dataset.build(train_ids, shuffle=True)
valid_dataset = dataset.build(valid_ids, shuffle=False)

# Prepare model
# Count the training batches in a single pass without keeping them in memory;
# the schedule needs the total number of optimizer steps over all epochs.
num_steps = sum(1 for _ in train_dataset)
schedule = WarmupLinearDecay(
    max_learning_rate=1e-3,
    warmup_steps=0,
    training_steps=num_steps * epochs,
)
optimizer = keras.optimizers.Adam(schedule, beta_1=0.9, beta_2=0.999, epsilon=1e-8, clipnorm=1.0)

model = PreLNDecoder(config)
model.compile(loss=PaddingLoss(), optimizer=optimizer)
# Build with an explicit input shape (the leading None leaves the batch
# dimension unspecified) before printing the summary.
model.build(input_shape=(None, config.context_size))
model.summary()

# Train
# EarlyStopping with patience=1 ends training after one epoch without
# improvement and restores the best weights seen so far.
history = model.fit(
    train_dataset,
    epochs=epochs,
    validation_data=valid_dataset,
    callbacks=[
        keras.callbacks.EarlyStopping(
            patience=1,
            restore_best_weights=True,
        ),
        # To save checkpoints, uncomment the next line
        # keras.callbacks.ModelCheckpoint("keras_model/", save_best_only=True)
    ],
)

# Evaluate
# Perplexity of the trained model on the validation split
ppl = perplexity(model, valid_dataset)
print("Validation PPL:", ppl)

# Generate
# One prompt of five token ids, shaped (1, 5)
inputs = np.array([[1, 2, 3, 4, 5]], dtype=np.int32)
gen = TopKTopPGenerator(model=model, max_len=3)
# Kept as a bare final expression so the notebook displays the returned ids
gen.generate(inputs)

Model: "pre_ln_decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
decoder (Decoder)            multiple                  312804    
Total params: 312,804
Trainable params: 312,804
Non-trainable params: 0
_________________________________________________________________
Epoch 1/10
Epoch 2/10
{'loss': 9.0920184e-05, 'perplexity': 1.000091, 'num_batches': 7, 'num_tokens': 807}
Validation PPL: 1.000091


array([[1, 2, 3, 4, 5, 6, 7, 8]])