In [1]:
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from generative.process import main
from generative.tnn import TransformerBlock, TokenAndPositionEmbedding
from generative.train import train_model, TrainTextGenerator, CustomSchedule
import pandas as pd
import configparser
import os

# Load and process the data.
# A fixed seed keeps the 50,000-row training subsample identical across
# Restart & Run All; without it, every run draws different rows.
SEED = 42
data = pd.read_parquet('./sampled_0623_1023.parquet').sample(50000, random_state=SEED)
train_ds, val_ds, test_ds, combined_vocab = main(data)

# Read config file.
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed, so check that result here instead of failing
# later with an opaque KeyError on "params".
CONFIG_PATH = './generative/config.ini'
config = configparser.ConfigParser()
if not config.read(CONFIG_PATH):
    raise FileNotFoundError(f"Could not read config file: {CONFIG_PATH}")
params = config["params"]
epochs = int(params['epochs'])

LOAD_MODEL = False
MODEL_PATH = './models/generative/model_2.h5'

# Load or train the model.
# Preloading is requested only when LOAD_MODEL is set AND the file exists;
# what train_model does with these arguments is defined in generative.train.
preload_model = LOAD_MODEL and os.path.exists(MODEL_PATH)
model = train_model(preload_model=preload_model, model_path=MODEL_PATH)

def get_callbacks():
    """Assemble the list of callbacks handed to ``model.fit``.

    Returns:
        list: three callbacks, in this order:
            1. ``ModelCheckpoint`` monitoring ``val_loss`` with
               ``save_best_only=True`` and ``save_weights_only=False``.
            2. ``TrainTextGenerator`` built from the notebook-level
               ``combined_vocab`` (read as a global).
            3. ``EarlyStopping`` monitoring ``val_loss`` with ``patience=5``
               and ``restore_best_weights=True``.
    """
    checkpoint_options = dict(
        filepath="./models/generative/weights.{epoch:02d}-{val_loss:.2f}.ckpt",
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=False,
        verbose=1,
    )
    stopping_options = dict(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True,
    )
    # Construction order matches the returned order.
    return [
        ModelCheckpoint(**checkpoint_options),
        TrainTextGenerator(index_to_word=combined_vocab),
        EarlyStopping(**stopping_options),
    ]

# Fit on the training dataset with validation, then write the resulting
# model to MODEL_PATH.
training_callbacks = get_callbacks()
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=training_callbacks,
)
model.save(MODEL_PATH)

Epoch 1/50
Epoch 1: val_loss improved from inf to 1.33411, saving model to ./models/generative\weights.01-1.33.ckpt




INFO:tensorflow:Assets written to: ./models/generative\weights.01-1.33.ckpt\assets


INFO:tensorflow:Assets written to: ./models/generative\weights.01-1.33.ckpt\assets


Epoch 2/50
Epoch 2: val_loss improved from 1.33411 to 0.23219, saving model to ./models/generative\weights.02-0.23.ckpt




INFO:tensorflow:Assets written to: ./models/generative\weights.02-0.23.ckpt\assets


INFO:tensorflow:Assets written to: ./models/generative\weights.02-0.23.ckpt\assets


Epoch 3/50
Epoch 3: val_loss improved from 0.23219 to 0.20815, saving model to ./models/generative\weights.03-0.21.ckpt




INFO:tensorflow:Assets written to: ./models/generative\weights.03-0.21.ckpt\assets


INFO:tensorflow:Assets written to: ./models/generative\weights.03-0.21.ckpt\assets


Epoch 4/50
Epoch 4: val_loss did not improve from 0.20815
Epoch 5/50
Epoch 5: val_loss did not improve from 0.20815
Epoch 6/50
Epoch 6: val_loss did not improve from 0.20815
Epoch 7/50
Epoch 7: val_loss did not improve from 0.20815
Epoch 8/50
Epoch 8: val_loss did not improve from 0.20815


#### Evaluations

In [3]:
import tensorflow as tf
import numpy as np
import pandas as pd
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import custom_object_scope
from generative.tnn import TransformerBlock, TokenAndPositionEmbedding
from generative.evaluate import TextGenerator, CustomSchedule
from generative.process import main

# Draw a separate 10,000-row sample for evaluation and rebuild the datasets.
# The seed is fixed (and chosen independently for this cell) so the evaluation
# sample is reproducible across kernel restarts.
# NOTE(review): main() returns a new combined_vocab for this sample, and the
# text generator in the next cell decodes with it. It presumably has to match
# the vocabulary the saved model was trained with — confirm that main() yields
# a stable vocabulary across different samples.
EVAL_SEED = 123
data = pd.read_parquet('./sampled_0623_1023.parquet').sample(10000, random_state=EVAL_SEED)
train_ds, val_ds, test_ds, combined_vocab = main(data)

In [None]:
# Project-defined classes that must be resolvable by name while the saved
# model is being loaded.
custom_objects = {
    'CustomSchedule': CustomSchedule,
    'TransformerBlock': TransformerBlock,
    'TokenAndPositionEmbedding': TokenAndPositionEmbedding,
}
model_directory = './models/generative/model_2.h5'
with custom_object_scope(custom_objects):
    gpt = load_model(model_directory)

text_generator = TextGenerator(gpt, index_to_word=combined_vocab)
# Seed the generator with a starter prompt; the call is left as the cell's
# last expression so its result is displayed.
text_generator.generate('the', max_tokens=25, temperature=1)