In [None]:
import configparser
import os

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import custom_object_scope
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard

from utils.process import main
# NOTE: the bare names below are imported from BOTH the general and the custom packages.
# Later imports shadow earlier ones, so `train_model`, `TrainTextGenerator`,
# `CustomSchedule`, `TransformerBlock` and `TokenAndPositionEmbedding` all end up bound to
# the custom_generative versions. Use the module aliases further down when a cell needs a
# specific variant.
from generative.general_generative.tnn import TransformerBlock, TokenAndPositionEmbedding
from generative.general_generative.train import train_model, TrainTextGenerator, CustomSchedule
from generative.general_generative.evaluate import TextGenerator, CustomSchedule
from generative.custom_generative.tnn import TransformerBlock, TokenAndPositionEmbedding
from generative.custom_generative.train import train_model, TrainTextGenerator, CustomSchedule

# Module aliases so each section can reference its own variant unambiguously.
import generative.general_generative.tnn as general_tnn
import generative.general_generative.train as general_train
import generative.general_generative.evaluate as general_evaluate
import generative.custom_generative.tnn as custom_tnn
import generative.custom_generative.train as custom_train

#### i. General Generative

In [None]:
# Load and process the data.
# A fixed seed makes the subsample (and therefore the vocabulary and the splits produced
# by `main`) reproducible across kernel restarts.
SEED = 42
N_SAMPLES = 25000
CONFIG_PATH = './generative/config.ini'

raw_data = pd.read_csv('./sampled_data.csv')
# Cap at the available row count: DataFrame.sample(n) raises if n > len(frame).
data = raw_data.sample(n=min(N_SAMPLES, len(raw_data)), random_state=SEED)
train_ds, val_ds, test_ds, combined_vocab = main(data)

# Read config file. ConfigParser.read silently skips missing files (it returns the list of
# files it did read), so fail loudly here instead of with a KeyError on config["params"].
config = configparser.ConfigParser()
if not config.read(CONFIG_PATH):
    raise FileNotFoundError(f"Config file not found: {CONFIG_PATH}")
params = config["params"]
epochs = params.getint('epochs')

# Paths and Flags
LOAD_MODEL = False
MODEL_PATH = './models/general_generative/model_2.h5'
# Load (when requested and present) or build the model. `general_train` is referenced
# explicitly because the bare `train_model` name is shadowed by the custom_generative
# import at the top of the notebook.
model = general_train.train_model(
    preload_model=LOAD_MODEL and os.path.exists(MODEL_PATH),
    model_path=MODEL_PATH,
)

def get_callbacks(checkpoint_dir="./models/general_generative", patience=5):
    """Build the training callbacks for the general generative model.

    Args:
        checkpoint_dir: Directory that receives the per-epoch checkpoints.
        patience: Number of epochs without `val_loss` improvement before early stopping.

    Returns:
        A list `[ModelCheckpoint, TrainTextGenerator, EarlyStopping]`.
    """
    model_checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(checkpoint_dir, "weights.{epoch:02d}-{val_loss:.2f}.ckpt"),
        save_weights_only=False,
        save_best_only=True,
        monitor='val_loss',
        verbose=1
    )
    # Use the general variant explicitly. The bare `TrainTextGenerator` name is shadowed by
    # the custom_generative import, and that class is constructed with `combined_vocab=`
    # elsewhere in this notebook, so `index_to_word=` likely does not match its signature.
    text_generator = general_train.TrainTextGenerator(index_to_word=combined_vocab)
    early_stopping_callback = EarlyStopping(
        monitor='val_loss',
        patience=patience,
        restore_best_weights=True
    )
    return [model_checkpoint_callback, text_generator, early_stopping_callback]

# Fit the general model (callbacks handle checkpointing, sample generation and early
# stopping), then persist the result to MODEL_PATH.
general_callbacks = get_callbacks()
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=general_callbacks,
)
model.save(MODEL_PATH)

In [None]:
# Reload the saved general model. The custom objects come from the general_generative
# modules explicitly, because the bare names are shadowed by the custom_generative imports.
# NOTE(review): assumes the schedule class used during training is the one defined in
# general_generative.train — confirm it matches the one in general_generative.evaluate.
general_custom_objects = {
    'CustomSchedule': general_train.CustomSchedule,
    'TransformerBlock': general_tnn.TransformerBlock,
    'TokenAndPositionEmbedding': general_tnn.TokenAndPositionEmbedding,
}
with custom_object_scope(general_custom_objects):
    gpt = load_model(MODEL_PATH)
# Generate with the freshly loaded `gpt`, so this cell exercises the saved artifact rather
# than the in-memory `model` left over from training.
text_generator = general_evaluate.TextGenerator(gpt, index_to_word=combined_vocab)
# input starter text
text = text_generator.generate('Test', max_tokens=100, temperature=1)
text

#### ii. Custom Generative

In [None]:
# Load and process the data.
# A fixed seed makes the subsample (and therefore the vocabulary and the splits produced
# by `main`) reproducible across kernel restarts.
SEED = 42
N_SAMPLES = 25000
CONFIG_PATH = './generative/config.ini'

raw_data = pd.read_csv('./sampled_data.csv')
# Cap at the available row count: DataFrame.sample(n) raises if n > len(frame).
data = raw_data.sample(n=min(N_SAMPLES, len(raw_data)), random_state=SEED)
train_ds, val_ds, test_ds, combined_vocab = main(data)

# Read config file. ConfigParser.read silently skips missing files (it returns the list of
# files it did read), so fail loudly here instead of with a KeyError on config["params"].
config = configparser.ConfigParser()
if not config.read(CONFIG_PATH):
    raise FileNotFoundError(f"Config file not found: {CONFIG_PATH}")
params = config["params"]
epochs = params.getint('epochs')

# Paths and Flags
LOAD_MODEL = False
MODEL_PATH = './models/custom_generative/model_2.h5'
# Load (when requested and present) or build the model. `custom_train` is referenced
# explicitly so correctness no longer depends on the order of the top-level imports,
# where same-named functions from two packages shadow each other.
model = custom_train.train_model(
    preload_model=LOAD_MODEL and os.path.exists(MODEL_PATH),
    model_path=MODEL_PATH,
)

def get_callbacks(checkpoint_dir="./models/custom_generative", patience=5):
    """Build the training callbacks for the custom generative model.

    Args:
        checkpoint_dir: Directory that receives the per-epoch checkpoints.
        patience: Number of epochs without `val_loss` improvement before early stopping.

    Returns:
        A list `[ModelCheckpoint, TrainTextGenerator, EarlyStopping]`.
    """
    model_checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(checkpoint_dir, "weights.{epoch:02d}-{val_loss:.2f}.ckpt"),
        save_weights_only=False,
        save_best_only=True,
        monitor='val_loss',
        verbose=1
    )
    # Reference the custom variant explicitly instead of relying on import order to pick
    # the right one of two same-named `TrainTextGenerator` classes.
    text_generator = custom_train.TrainTextGenerator(combined_vocab=combined_vocab)
    early_stopping_callback = EarlyStopping(
        monitor='val_loss',
        patience=patience,
        restore_best_weights=True
    )
    return [model_checkpoint_callback, text_generator, early_stopping_callback]

# Fit the custom model (callbacks handle checkpointing, sample generation and early
# stopping), then persist the result to MODEL_PATH.
custom_callbacks = get_callbacks()
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=custom_callbacks,
)
model.save(MODEL_PATH)

In [None]:
class TestTextGenerator(tf.keras.callbacks.Callback):
    """Keras callback that prints a sampled continuation when evaluation finishes.

    Args:
        combined_vocab: Sequence of vocabulary words; position ``i`` is the word for token id ``i``.
        top_k: Accepted for interface compatibility but currently unused — sampling is over
            the full temperature-scaled distribution.
    """

    def __init__(self, combined_vocab, top_k=15):
        super().__init__()
        self.index_to_word = combined_vocab
        self.word_to_index = {word: index for index, word in enumerate(combined_vocab)}
        # Stored for reference only; `generate` does not apply top-k filtering.
        self.top_k = top_k

    def set_model(self, model):
        # Delegate to the base class. In Keras 3 `Callback.model` is a read-only property
        # backed by `_model`, so assigning `self.model` directly would raise.
        super().set_model(model)

    def sample_from(self, probs, temperature):
        """Sample one token id from ``probs`` after temperature scaling.

        Args:
            probs: 1-D array of non-negative next-token probabilities (assumed to be softmax
                output — TODO confirm the model does not emit logits).
            temperature: Positive float; values < 1 sharpen, values > 1 flatten the distribution.

        Returns:
            ``(token_id, scaled_probs)``, where ``scaled_probs`` is the renormalized distribution.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        # Renormalize in float64 so the vector sums to 1 within np.random.choice's tolerance.
        scaled_probs = np.asarray(probs, dtype=np.float64) ** (1 / temperature)
        scaled_probs /= np.sum(scaled_probs)
        return np.random.choice(len(scaled_probs), p=scaled_probs), scaled_probs

    @staticmethod
    def _next_token_probs(outputs):
        """Return the 1-D next-token distribution from a ``model.predict`` result.

        Accepts either a single probability array or a list/tuple whose first element is
        the probability array (e.g. ``[probs, attention_scores]``).

        Raises:
            ValueError: If the probability array is not 2-D or 3-D.
        """
        probs = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
        probs = np.asarray(probs)
        if probs.ndim == 3:
            # (batch, seq_len, vocab): the prediction for the next token is at the LAST position.
            return probs[0, -1]
        if probs.ndim == 2:
            # (batch, vocab): the model already emits only the final-position distribution.
            return probs[0]
        raise ValueError(f"Unexpected model output shape {probs.shape}")

    def generate(self, start_prompt, max_tokens, temperature):
        """Autoregressively extend ``start_prompt`` by up to ``max_tokens`` sampled tokens.

        Prompt words missing from the vocabulary map to token id 1 (presumably the OOV id —
        confirm against the vectorizer in ``utils.process``). Generation stops early when
        token 0 is sampled.

        NOTE(review): the context grows by one token per step; if the model's positional
        embedding has a fixed maximum length, long generations may exceed it.

        Returns:
            The generated tokens (prompt excluded) joined by spaces; ids at or beyond
            ``len(index_to_word)`` are dropped.

        Raises:
            ValueError: If ``start_prompt`` contains no words, or ``temperature`` <= 0.
        """
        start_tokens = [self.word_to_index.get(word, 1) for word in start_prompt.split()]
        if not start_tokens:
            raise ValueError("start_prompt must contain at least one word")
        generated_tokens = []
        for _ in range(max_tokens):
            x = np.array([start_tokens])
            outputs = self.model.predict(x, verbose=0)
            sample_token, _ = self.sample_from(self._next_token_probs(outputs), temperature)

            generated_tokens.append(sample_token)
            start_tokens.append(sample_token)
            if sample_token == 0:
                break
        return " ".join(
            self.index_to_word[token]
            for token in generated_tokens
            if token < len(self.index_to_word)
        )

    def on_test_end(self, logs=None):
        """Print a sample generation once ``model.evaluate`` completes."""
        generated_text = self.generate("This year has been", max_tokens=100, temperature=1)
        print(f"Generated text: {generated_text}")

# Build the evaluation-time text-generation callback from the vocabulary of the
# custom-generative data cell.
test_text_gen = TestTextGenerator(combined_vocab)
# Reload the saved custom model. The custom objects come from the custom_generative modules
# explicitly, rather than relying on import order to resolve the same-named classes.
model = load_model(MODEL_PATH, custom_objects={
    "TransformerBlock": custom_tnn.TransformerBlock,
    "TokenAndPositionEmbedding": custom_tnn.TokenAndPositionEmbedding,
    "CustomSchedule": custom_train.CustomSchedule,
})
# Evaluate on the held-out split; the callback prints a sample generation when evaluation ends.
model.evaluate(test_ds, callbacks=[test_text_gen])