In [None]:
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import custom_object_scope
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from tensorflow.keras import layers, models, losses, callbacks

import pandas as pd
import numpy as np
import configparser
import os
from utils.process import main
# Reproducibility: seed the global RNGs once, early, and pass an explicit
# random_state to the row sample so the subset (and hence the vocab and the
# train/val/test splits built from it) is identical across kernel restarts.
SEED = 42
N_SAMPLES = 25000
np.random.seed(SEED)
tf.random.set_seed(SEED)

raw_data = pd.read_csv('../sampled_data.csv')
# Cap the sample size at the number of available rows: DataFrame.sample raises
# ValueError when n exceeds len(frame) and replace=False.
data = raw_data.sample(n=min(N_SAMPLES, len(raw_data)), random_state=SEED)
train_ds, val_ds, test_ds, combined_vocab = main(data)

#### i. General Generative Model

In [None]:
from generative.general_generative.tnn import TransformerBlock, TokenAndPositionEmbedding
from generative.general_generative.train import train_model, TrainTextGenerator, CustomSchedule
from generative.general_generative.evaluate import TextGenerator, CustomSchedule

# Read hyper-parameters from the shared config file. ConfigParser.read()
# silently skips files it cannot open and returns the list of files it did
# read, so fail loudly here instead of with a bare KeyError on "params".
CONFIG_PATH = './generative/config.ini'
config = configparser.ConfigParser()
if not config.read(CONFIG_PATH):
    raise FileNotFoundError(f"Could not read config file: {CONFIG_PATH}")
params = config["params"]
epochs = int(params['epochs'])

# Paths and Flags
LOAD_MODEL = False  # True -> call train_model(preload_model=True) when MODEL_PATH exists
MODEL_PATH = './models/general_generative/model_1.h5'
# Ensure the output directory exists before checkpoints / model.save() write into it.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)

# Load or train the model: preload only when requested AND a saved file is present.
preload = LOAD_MODEL and os.path.exists(MODEL_PATH)
model = train_model(preload_model=preload, model_path=MODEL_PATH)

def get_callbacks():
    """Build the Keras callbacks used while fitting the general generative model.

    Returns:
        list: ``[ModelCheckpoint, TrainTextGenerator, EarlyStopping]``, in that order.

        * ModelCheckpoint saves only when ``val_loss`` improves (``save_best_only``).
        * TrainTextGenerator is constructed from the notebook-level ``combined_vocab``.
        * EarlyStopping stops after 5 epochs without ``val_loss`` improvement and
          is configured with ``restore_best_weights=True``.
    """
    # NOTE(review): save_weights_only=False writes the full model even though the
    # filename says "weights" and ends in ".ckpt"; recent Keras releases require a
    # ".keras" suffix for full-model checkpoints — confirm against the installed TF.
    checkpoint = ModelCheckpoint(
        filepath="./models/general_generative/weights.{epoch:02d}-{val_loss:.2f}.ckpt",
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=False,
        verbose=1,
    )
    sample_generator = TrainTextGenerator(index_to_word=combined_vocab)
    early_stop = EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True,
    )
    return [checkpoint, sample_generator, early_stop]

# Fit on the training split, validating on val_ds after every epoch, with the
# callbacks returned by get_callbacks().
training_callbacks = get_callbacks()
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=training_callbacks,
)
# Persist the fitted model to the path used by the evaluation cell.
model.save(MODEL_PATH)

In [None]:
from generative.general_generative.tnn import TransformerBlock, TokenAndPositionEmbedding
from generative.general_generative.evaluate import TextGenerator, CustomSchedule
MODEL_PATH = './models/general_generative/model_1.h5'
# Generation settings for the qualitative sample below.
PROMPT = "Today in the news"
TOP_K = 15
MAX_TOKENS = 50
TEMPERATURE = 1.0

# Register the project's custom objects so Keras can deserialize the saved file.
# NOTE(review): CustomSchedule here comes from the evaluate module, while the
# training cell also imports a CustomSchedule from ...general_generative.train —
# confirm the two are serialization-compatible.
with custom_object_scope({'CustomSchedule': CustomSchedule, 'TransformerBlock': TransformerBlock, 'TokenAndPositionEmbedding': TokenAndPositionEmbedding}):
    model = load_model(MODEL_PATH)

test_text_gen = TextGenerator(model=model, index_to_word=combined_vocab, top_k=TOP_K, generation_type='general', sampling_type='top_k')

# Leave the result as the cell's last expression so Jupyter displays it.
info = test_text_gen.generate(PROMPT, max_tokens=MAX_TOKENS, temperature=TEMPERATURE)
info

#### ii. Custom Generative Model

In [None]:
from generative.custom_generative.tnn import TransformerBlock, TokenAndPositionEmbedding
from generative.custom_generative.train import train_model, TrainTextGenerator, CustomSchedule

# Read hyper-parameters from the shared config file. ConfigParser.read()
# silently skips files it cannot open and returns the list of files it did
# read, so fail loudly here instead of with a bare KeyError on "params".
CONFIG_PATH = './generative/config.ini'
config = configparser.ConfigParser()
if not config.read(CONFIG_PATH):
    raise FileNotFoundError(f"Could not read config file: {CONFIG_PATH}")
params = config["params"]
epochs = int(params['epochs'])

# Paths and Flags
LOAD_MODEL = True  # True -> call train_model(preload_model=True) when MODEL_PATH exists
MODEL_PATH = './models/custom_generative/model_1.h5'
# Ensure the output directory exists before checkpoints / model.save() write into it.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)

# Load or train the model: preload only when requested AND a saved file is present.
preload = LOAD_MODEL and os.path.exists(MODEL_PATH)
model = train_model(preload_model=preload, model_path=MODEL_PATH)

def get_callbacks():
    """Build the Keras callbacks used while fitting the custom generative model.

    Returns:
        list: ``[ModelCheckpoint, TrainTextGenerator, EarlyStopping]``, in that order.

        * ModelCheckpoint saves only when ``val_loss`` improves (``save_best_only``).
        * TrainTextGenerator is constructed from the notebook-level ``combined_vocab``.
        * EarlyStopping stops after 5 epochs without ``val_loss`` improvement and
          is configured with ``restore_best_weights=True``.
    """
    # NOTE(review): save_weights_only=False writes the full model even though the
    # filename says "weights" and ends in ".ckpt"; recent Keras releases require a
    # ".keras" suffix for full-model checkpoints — confirm against the installed TF.
    checkpoint = ModelCheckpoint(
        filepath="./models/custom_generative/weights.{epoch:02d}-{val_loss:.2f}.ckpt",
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=False,
        verbose=1,
    )
    sample_generator = TrainTextGenerator(combined_vocab=combined_vocab)
    early_stop = EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True,
    )
    return [checkpoint, sample_generator, early_stop]

# Fit on the training split, validating on val_ds after every epoch, with the
# callbacks returned by get_callbacks().
training_callbacks = get_callbacks()
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=training_callbacks,
)

# Persist the fitted model to the path used by the evaluation cell.
model.save(MODEL_PATH)

In [None]:
from generative.custom_generative.tnn import TransformerBlock, TokenAndPositionEmbedding
from generative.custom_generative.train import train_model, TrainTextGenerator, CustomSchedule
from generative.custom_generative.evaluate import TextGenerator
MODEL_PATH = './models/custom_generative/model_1.h5'
# Load the model, registering the project's custom objects so Keras can
# deserialize them from the saved file.
model = load_model(MODEL_PATH, custom_objects={
    "TransformerBlock": TransformerBlock,
    "TokenAndPositionEmbedding": TokenAndPositionEmbedding,
    "CustomSchedule": CustomSchedule})

# TextGenerator is handed to evaluate() as a callback, so it presumably
# subclasses keras.callbacks.Callback — TODO confirm.
test_text_gen = TextGenerator(model, vocab=combined_vocab, sampling_type='top_p', generation_type='general')

# Evaluate on the test split returned by main() — not train_ds — so the score
# reflects data the model was not fitted on. Keep the metrics as the cell output.
test_metrics = model.evaluate(test_ds, callbacks=[test_text_gen])
test_metrics