In [None]:
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import custom_object_scope
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from tensorflow.keras import layers, models, losses, callbacks

import pandas as pd
import numpy as np
import configparser
import os
from utils.process import main
# Sub-sample the raw CSV before handing it to the project's preprocessing
# entry point. The seed makes the sample identical after
# "Restart Kernel -> Run All"; the size is capped so a CSV with fewer rows
# than SAMPLE_SIZE is used in full (shuffled) instead of raising ValueError.
SAMPLE_SIZE = 25000
SEED = 42
raw_data = pd.read_csv('../sampled_data.csv')
data = raw_data.sample(n=min(SAMPLE_SIZE, len(raw_data)), random_state=SEED)
# `main` is the project-local helper from utils.process; it returns the three
# dataset splits plus the vocabulary used by the text generators below.
train_ds, val_ds, test_ds, combined_vocab = main(data)

#### i. General Generative

In [None]:
from generative_text.general_generative_keras.tnn import TransformerBlock, TokenAndPositionEmbedding
from generative_text.general_generative_keras.train import train_model, TrainTextGenerator, CustomSchedule
from generative_text.general_generative_keras.evaluate import TextGenerator, CustomSchedule

# Read the training settings for the general generative model from its INI file.
config = configparser.ConfigParser()
config.read('./generative_text/configkeras.ini')
params = config["params"]
epochs = int(params['epochs'])  # INI values are strings; cast explicitly

LOAD_MODEL = False
MODEL_PATH = './models/general_generative/model_1.h5'

# Preloading is requested only when it is both enabled and a saved model file
# is actually present; in every other case train_model gets preload_model=False.
should_preload = LOAD_MODEL and os.path.exists(MODEL_PATH)
model = train_model(preload_model=should_preload, model_path=MODEL_PATH)

def get_callbacks():
    """Assemble the callbacks passed to ``model.fit`` for the general model.

    Returns:
        list: In order, a ``ModelCheckpoint`` (keeps only the best epoch by
        ``val_loss``), a ``TrainTextGenerator`` built from the notebook-level
        ``combined_vocab``, and an ``EarlyStopping`` (patience 5 on
        ``val_loss``, restoring the best weights).
    """
    checkpoint_options = dict(
        filepath="./models/general_generative/weights.{epoch:02d}-{val_loss:.2f}.ckpt",
        save_weights_only=False,
        save_best_only=True,
        monitor='val_loss',
        verbose=1,
    )
    stopping_options = dict(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True,
    )
    return [
        ModelCheckpoint(**checkpoint_options),
        TrainTextGenerator(index_to_word=combined_vocab),
        EarlyStopping(**stopping_options),
    ]

# Train on the training split, validating each epoch, then persist the result.
fit_options = {
    "epochs": epochs,
    "validation_data": val_ds,
    "callbacks": get_callbacks(),
}
model.fit(train_ds, **fit_options)
model.save(MODEL_PATH)

In [None]:
from generative_text.general_generative_keras.tnn import TransformerBlock, TokenAndPositionEmbedding
# The package is `generative_text` in every other import and config path of this
# notebook; the previous `generative.` prefix did not match any of them.
from generative_text.general_generative_keras.evaluate import TextGenerator, CustomSchedule

MODEL_PATH = './models/general_generative/model_1.h5'

# Custom layers and the learning-rate schedule are registered by name for the
# duration of the load so the saved model can reference them.
general_custom_objects = {
    'CustomSchedule': CustomSchedule,
    'TransformerBlock': TransformerBlock,
    'TokenAndPositionEmbedding': TokenAndPositionEmbedding,
}
with custom_object_scope(general_custom_objects):
    model = load_model(MODEL_PATH)

# Generate a continuation of a fixed prompt with top-k sampling (k=15).
test_text_gen = TextGenerator(model=model, index_to_word=combined_vocab, top_k=15, generation_type='general', sampling_type='top_k')
info = test_text_gen.generate("Today in the news", max_tokens=50, temperature=1.0)

#### ii. Custom Generative

In [None]:
import tensorflow as tf
from general_chat_custom.preprocessing import DirectoryManager  
from general_chat_custom.preprocessing import initialize_and_prepare  
from general_chat_custom.processing import process_and_load_data

import pandas as pd
import configparser

# --- Configuration ----------------------------------------------------------
# Parse the custom-model settings. configparser yields strings, so each value
# is cast to its numeric type below.
config = configparser.ConfigParser()
config.read('./generative_text/configcustom.ini')
config_params = config['params']
params = dict(config_params)

base_directory = params['dataset_path']
max_len, vocab_size, embedding_dim = (
    int(params[option]) for option in ('max_len', 'vocab_size', 'embedding_dim')
)
num_heads, num_layers, key_dim = (
    int(params[option]) for option in ('n_heads', 'n_layers', 'key_dim')
)
ff_dim = int(params['feed_forward_dim'])
dropout_rate = float(params['dropout'])
warmup_steps = int(params['warmup_steps'])
activation = params['activation']
epochs = int(params['epochs'])

# --- Data -------------------------------------------------------------------
# NOTE(review): absolute Colab/Drive path; this cell only runs where Drive is
# mounted at /content/drive.
replies_csv = '/content/drive/MyDrive/research/chanscope/data/replies/replies.csv'
replies = pd.read_csv(replies_csv).drop_duplicates()

# Create the directory layout derived from the configured dataset path.
config_directories = DirectoryManager.generate_config(base_directory)
DirectoryManager.create_directories(config_directories)
meta_data_dir = config_directories['meta_data_dir']

# Supporting artefacts: the text pairs and the two vocabularies.
text_pairs, voc_comment, voc_response_comment = initialize_and_prepare(base_directory, replies)
# Dataset splits plus the two vectorizers reused later for decoding.
train_ds, val_ds, test_ds, thread_vectorizer, comment_vectorizer = process_and_load_data(replies)

In [None]:
from model_evaluation import plot_history, evaluate_model, plot_text_pair_distribution, count_tokens_and_lengths

# Plot the training curves from `history`.
# NOTE(review): `history` is first assigned by `model.fit(...)` in a LATER cell
# of this notebook, so on a fresh kernel this line raises NameError — this cell
# presumably belongs after the custom-model training cell; confirm and move.
plot_history(history)
# Compute precision / recall / F1 for `model` on the held-out data.
# NOTE(review): `x_test` and `y_test` are not assigned in any visible cell, and
# `model` here is whatever an earlier cell left in the kernel — verify both.
precision, recall, f1 = evaluate_model(model, x_test, y_test)
# `text_pairs` is produced by `initialize_and_prepare` in the preceding cell;
# presumably a sequence of (thread, comment) tuples — TODO confirm.
plot_text_pair_distribution(text_pairs)
count_tokens_and_lengths(text_pairs)

In [None]:
from general_chat_custom.tnn import transformer, masked_loss, masked_accuracy
from general_chat_custom.tnn import CustomSchedule
from general_chat_custom.PositionalEmbedding import PositionalEmbedding
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import tensorflow as tf
import os

# Constants
# Where the custom transformer is loaded from (if the file exists) and where
# it is saved after training in this cell.
# NOTE(review): relative to the current working directory, whereas the replies
# CSV above is read from an absolute /content/drive path — confirm the cwd.
model_path = './drive/MyDrive/research/chanscope/generative_custom_models/model_1.h5'

# Function to get callbacks
def get_callbacks():
    """Build the callbacks used while training the custom transformer.

    Returns:
        list: ``[EarlyStopping, ModelCheckpoint]``.
    """
    # The monitored quantity is an accuracy, so higher is better. The direction
    # is stated explicitly rather than left to Keras' name-based "auto" mode,
    # which is not guaranteed to recognise a custom metric function.
    early_stopping = EarlyStopping(
        monitor='val_masked_accuracy',
        mode='max',
        patience=5,
        restore_best_weights=True,
    )
    # ModelCheckpoint keeps its default monitor (val_loss), which is also the
    # value formatted into the checkpoint filename.
    model_checkpoint = ModelCheckpoint(
        filepath='./drive/MyDrive/research/chanscope/generative_custom_models/weights.{epoch:02d}-{val_loss:.2f}.ckpt',
        save_best_only=True,
    )
    return [early_stopping, model_checkpoint]

# Load or create model
if os.path.exists(model_path):
    print("Model found. Loading...")
    # Register every custom class/function the saved model may reference.
    # PositionalEmbedding and CustomSchedule are included so this load uses the
    # same set of custom objects as the reload in the inference cell.
    model = tf.keras.models.load_model(
        model_path,
        custom_objects={
            'PositionalEmbedding': PositionalEmbedding,
            'CustomSchedule': CustomSchedule,
            'masked_loss': masked_loss,
            'masked_accuracy': masked_accuracy,
        },
    )
else:
    print("No model found. Creating a new one.")
    # NOTE(review): `comment_tokens` and `response_comment_tokens` are not
    # assigned in any visible cell; presumably they correspond to the
    # `voc_comment` / `voc_response_comment` vocabularies returned by
    # `initialize_and_prepare` — confirm before relying on this branch.
    model = transformer(num_layers, num_heads, max_len, key_dim, ff_dim, len(comment_tokens), len(response_comment_tokens), dropout_rate)

# Compile model (a freshly loaded model is recompiled too, with a new optimizer)
lr = CustomSchedule(key_dim)
optimizer = tf.keras.optimizers.Adam(lr, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
model.compile(loss=masked_loss, optimizer=optimizer, metrics=[masked_accuracy])

# Train model. The callback list is passed inline so no notebook-level
# `callbacks` variable shadows the `tensorflow.keras.callbacks` module
# imported in the first cell.
history = model.fit(train_ds, epochs=epochs, validation_data=val_ds, callbacks=get_callbacks())

# Save model
model.save(model_path)

In [None]:
from matplotlib import pyplot as plt

# Two stacked panels sharing the epoch axis: loss on top, masked accuracy below.
fig, axs = plt.subplots(2, sharex=True, figsize=(10, 6))
fig.suptitle('Training history')
# Use the number of epochs actually recorded, not the configured maximum.
actual_epochs = len(history.history["loss"])
x = list(range(1, actual_epochs + 1))

panels = (
    (axs[0], "Loss", ("loss", "val_loss")),
    (axs[1], "Masked Accuracy", ("masked_accuracy", "val_masked_accuracy")),
)
for ax, ylabel, metric_names in panels:
    for metric_name in metric_names:
        ax.plot(x, history.history[metric_name], alpha=0.5, label=metric_name)
    ax.set_ylabel(ylabel)
    # A pyplot-level legend() call only affects the current (last) axes, which
    # left the loss panel without a legend; add one per panel instead.
    ax.legend()

# With sharex=True only the bottom panel needs the x label.
axs[-1].set_xlabel("Epochs")
plt.show()

In [None]:
# Load the trained model
import random
# Custom layer, schedule, loss and metric, registered by name so that
# load_model can resolve them while deserializing.
custom_objects = dict(
    PositionalEmbedding=PositionalEmbedding,
    CustomSchedule=CustomSchedule,
    masked_loss=masked_loss,
    masked_accuracy=masked_accuracy,
)

# NOTE(review): this path differs from the `model_path` the training cell saves
# to (a .h5 file under ./drive/MyDrive/...) — confirm which file is intended.
with tf.keras.utils.custom_object_scope(custom_objects):
    model = tf.keras.models.load_model('./models/model_1.keras')

# Translate function: greedily decode an output sentence for one input thread
def translate(sentence):
    """Greedily decode an output token sequence for ``sentence``.

    The input is vectorized once with ``thread_vectorizer``. The output starts
    as ``["[start]"]``; on each of at most ``max_len`` steps the partial output
    is re-vectorized with ``comment_vectorizer``, the model is run, and the
    highest-scoring vocabulary entry at the current position is appended.
    Decoding stops early once ``"[end]"`` is produced.

    Args:
        sentence: Raw input text (a single string).

    Returns:
        list[str]: Generated tokens, including the leading ``"[start]"`` and,
        if it was produced, the trailing ``"[end]"``.
    """
    start_sentinel, end_sentinel = "[start]", "[end]"

    # Encoder side: one example, reshaped so a batch dimension of 1 is explicit.
    enc_tokens = tf.reshape(thread_vectorizer([sentence]), (1, -1))
    vocabulary = list(comment_vectorizer.get_vocabulary())
    decoded = [start_sentinel]

    for position in range(max_len):
        partial_text = " ".join(decoded)
        # Drop the last column before feeding the decoder, then make the batch
        # dimension explicit again.
        dec_tokens = tf.reshape(comment_vectorizer([partial_text])[:, :-1], (1, -1))
        scores = model([enc_tokens, dec_tokens])

        # Greedy choice: highest-scoring vocabulary index at this position.
        next_word = vocabulary[tf.argmax(scores[0, position, :]).numpy()]
        decoded.append(next_word)

        if next_word == end_sentinel:
            break

    return decoded

# Spot-check: decode a few randomly chosen pairs and print each next to its
# reference comment.
test_count = 5
for n in range(test_count):
    thread, comment = random.choice(text_pairs)
    translated = translate(thread)
    report_lines = (
        f"Test {n}:",
        f"{thread}",
        f"== {comment}",
        f"-> {' '.join(translated)}",
        "",  # trailing blank line between samples
    )
    print(*report_lines, sep="\n")