In [None]:
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import custom_object_scope
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from tensorflow.keras import layers, models, losses, callbacks

import pandas as pd
import numpy as np
import configparser
import os
from utils.process import main
# Read the full corpus, then draw a reproducible subsample of at most 25,000 rows.
# The fixed random_state makes a kernel restart train on the same rows, and the
# min() guard keeps .sample() from raising when the CSV holds fewer rows than requested.
corpus_raw = pd.read_csv('../sampled_data.csv')
data = corpus_raw.sample(n=min(25000, len(corpus_raw)), random_state=42)
train_ds, val_ds, test_ds, combined_vocab = main(data)

#### i. General Generative

In [None]:
from generative_text.general_generative_keras.tnn import TransformerBlock, TokenAndPositionEmbedding
from generative_text.general_generative_keras.train import train_model, TrainTextGenerator, CustomSchedule
from generative_text.general_generative_keras.evaluate import TextGenerator, CustomSchedule

# Read the training settings for the general generative model from its INI file.
config = configparser.ConfigParser()
config.read('./generative_text/configkeras.ini')
params = config["params"]
epochs = int(params['epochs'])

LOAD_MODEL = False
MODEL_PATH = './models/general_generative/model_1.h5'

# Ask train_model to preload only when loading was requested AND the file exists;
# in every other case it is called with preload_model=False.
model = train_model(
    preload_model=bool(LOAD_MODEL and os.path.exists(MODEL_PATH)),
    model_path=MODEL_PATH,
)

def get_callbacks():
    """Build the callbacks passed to ``model.fit`` for the general generative model.

    Returns:
        list: ``[ModelCheckpoint, TrainTextGenerator, EarlyStopping]``, in that
        order. The checkpoint and early-stopping callbacks both monitor
        ``val_loss``; the text generator receives the module-level
        ``combined_vocab`` as its ``index_to_word`` argument.
    """
    checkpoint_options = dict(
        filepath="./models/general_generative/weights.{epoch:02d}-{val_loss:.2f}.ckpt",
        save_weights_only=False,
        save_best_only=True,
        monitor='val_loss',
        verbose=1,
    )
    stopping_options = dict(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True,
    )
    # The list literal constructs the callbacks in the order they are returned.
    return [
        ModelCheckpoint(**checkpoint_options),
        TrainTextGenerator(index_to_word=combined_vocab),
        EarlyStopping(**stopping_options),
    ]

# Train with val_ds as validation data and the callbacks from get_callbacks(),
# then write the resulting model to MODEL_PATH.
model.fit(train_ds, validation_data=val_ds, epochs=epochs, callbacks=get_callbacks())
model.save(MODEL_PATH)

In [None]:
from generative_text.general_generative_keras.tnn import TransformerBlock, TokenAndPositionEmbedding
# Fixed package name: every other import in this notebook uses `generative_text`;
# the previous `generative.` prefix does not match them and fails on a fresh kernel.
from generative_text.general_generative_keras.evaluate import TextGenerator, CustomSchedule

MODEL_PATH = './models/general_generative/model_1.h5'

# The saved model references custom classes, so they are registered by name
# while the file is deserialized.
with custom_object_scope({'CustomSchedule': CustomSchedule, 'TransformerBlock': TransformerBlock, 'TokenAndPositionEmbedding': TokenAndPositionEmbedding}):
    model = load_model(MODEL_PATH)

# Generate a sample continuation with top-k sampling (k=15).
test_text_gen = TextGenerator(model=model, index_to_word=combined_vocab, top_k=15, generation_type='general', sampling_type='top_k')
info = test_text_gen.generate("Today in the news", max_tokens=50, temperature=1.0)

#### ii. Custom Generative

In [None]:
def update_replies(sample_size=150000, raw_data_path='/content/drive/MyDrive/research/chanscope/data/replies_raw_2.csv', replies_path='/content/drive/MyDrive/research/chanscope/data/replies/replies.csv'):
    """Move a random sample of raw rows into the processed replies file.

    The raw CSV is read once, up to ``sample_size`` rows are sampled, run through
    ``prepare_data`` -> ``find_dialogs`` -> ``augment_dialogs``, and the resulting
    rows are merged (de-duplicated) into the CSV at ``replies_path``. The sampled
    rows are then removed from the raw CSV, which is overwritten in place.

    NOTE(review): ``prepare_data``, ``find_dialogs`` and ``augment_dialogs`` are
    not defined in this cell and must already be in the kernel namespace.

    Args:
        sample_size: Maximum number of raw rows to process. If the raw file has
            fewer rows, all of them are used instead of raising.
        raw_data_path: CSV of unprocessed rows; rewritten without the sampled rows.
        replies_path: CSV of accumulated replies; created if it does not exist.

    Returns:
        tuple: ``(complete_replies, remaining_data)`` -- the merged replies frame
        and the raw rows that were not sampled.
    """
    # Read the raw data once; the same frame is reused below to compute the
    # remainder instead of re-reading the file.
    all_raw = pd.read_csv(raw_data_path)
    # Clamp the request so a small file does not make .sample() raise.
    raw_data = all_raw.sample(min(sample_size, len(all_raw)))

    # Prepare the data
    prepared_data = prepare_data(raw_data)
    thread_headers = prepared_data.dropna(subset=['text_clean', 'posted_comment'])[
        ['thread_id', 'thread_header', 'posted_comment', 'posted_date_time']
    ]

    # Find and augment dialogs
    new_replies = find_dialogs(thread_headers)
    new_replies = augment_dialogs(new_replies, prepared_data)
    new_replies = new_replies.dropna()

    # Append to the existing replies when there are any; on the first run the
    # replies file may not exist yet, in which case the new rows start it.
    if os.path.exists(replies_path):
        complete_replies = pd.concat([pd.read_csv(replies_path), new_replies])
    else:
        complete_replies = new_replies
    complete_replies = complete_replies.drop_duplicates().reset_index(drop=True)

    # Save the updated replies
    complete_replies.to_csv(replies_path, index=False)

    # Remove the sampled rows from the original dataset and save it. The sample
    # keeps the index labels of all_raw, so an index membership test identifies them.
    remaining_data = all_raw.loc[~all_raw.index.isin(raw_data.index)]
    remaining_data.to_csv(raw_data_path, index=False)
    return complete_replies, remaining_data

# Run the function. update_replies returns (complete_replies, remaining_data),
# so the result is unpacked in that order (the names were previously swapped).
complete_replies, remaining_data = update_replies()

In [None]:
import pandas as pd
from utils.fnProcessing import find_dialogs, augment_dialogs,view_shapes

import tensorflow as tf
from generative_text.general_chat_custom.preprocessing import DirectoryManager  
from generative_text.general_chat_custom.preprocessing import initialize_and_prepare  
from generative_text.general_chat_custom.processing import process_and_load_data
from generative_text.general_chat_custom.evaluate import plot_text_pair_distribution, count_tokens_and_lengths,plot_history
import pandas as pd
import configparser

# Parse the custom-chat configuration and expose its sections as plain dicts.
config = configparser.ConfigParser()
config.read('./generative_text/configcustom.ini')
config_params, config_paths = config['params'], config['paths']
paths = dict(config_paths)
base_directory = paths['metadata_path']
params = dict(config_params)

# INI values are strings, so each hyper-parameter is converted to its numeric type.
max_len = int(params['max_len'])
vocab_size = int(params['vocab_size'])
embedding_dim = int(params['embedding_dim'])
num_heads = int(params['n_heads'])
num_layers = int(params['n_layers'])
key_dim = int(params['key_dim'])
ff_dim = int(params['feed_forward_dim'])
dropout_rate = float(params['dropout'])
warmup_steps = int(params['warmup_steps'])
activation = params['activation']
epoch = int(params['epochs'])

data_path = '../replies.csv'
# De-duplicate first, then draw a reproducible subsample of at most 2,500 pairs.
# The fixed random_state keeps reruns comparable, and the min() guard avoids a
# ValueError when fewer unique rows are available than requested.
replies_unique = pd.read_csv(data_path).drop_duplicates()
replies = replies_unique.sample(n=min(2500, len(replies_unique)), random_state=42)
config_directories = DirectoryManager.generate_config(base_directory)
DirectoryManager.create_directories(config_directories)
# Load supporting data
text_pairs, voc_comment, voc_response_comment = initialize_and_prepare(base_directory, replies)

# Load data
train_ds, val_ds, test_ds, comment_vectorizer, response_comment_vectorizer = process_and_load_data(replies)
comment_tokens, response_comment_tokens, comment_maxlen, response_maxlen = count_tokens_and_lengths(text_pairs)
view_shapes(train_ds)

In [None]:
from generative_text.general_chat_custom.tnn import transformer, masked_loss, masked_accuracy
from generative_text.general_chat_custom.tnn import CustomSchedule
from generative_text.general_chat_custom.PositionalEmbedding import PositionalEmbedding
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import tensorflow as tf
import os

# Checkpoint location: directory, file name, and the two joined into one path.
model_path, model_name = './models/', 'best_model.h5'
model_full_path = os.path.join(model_path, model_name)

def get_callbacks(model_path, model_name, patience=5):
    """Create the early-stopping and checkpoint callbacks for the custom model.

    Args:
        model_path: Directory the checkpoint file is written to.
        model_name: File name of the checkpoint inside ``model_path``.
        patience: Passed to ``EarlyStopping`` as its ``patience`` argument.

    Returns:
        list: ``[EarlyStopping, ModelCheckpoint]``. Early stopping monitors
        ``val_masked_accuracy`` (without restoring best weights); the checkpoint
        monitors ``val_loss`` and keeps only the best full model.
    """
    checkpoint_file = os.path.join(model_path, model_name)
    stopper = EarlyStopping(
        monitor='val_masked_accuracy',
        patience=patience,
        restore_best_weights=False,
    )
    saver = ModelCheckpoint(
        filepath=checkpoint_file,
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=False,
    )
    return [stopper, saver]

# Load or create model: build a fresh transformer when no checkpoint file exists,
# otherwise restore the saved one with its custom objects registered.
if not os.path.isfile(model_full_path):
    print("No model checkpoint found. Creating a new one.")
    transformer_model = transformer(
        num_layers=num_layers,
        num_heads=num_heads,
        key_dim=key_dim,
        ff_dim=ff_dim,
        # Vocabulary sizes are the lengths of the token collections computed
        # by count_tokens_and_lengths in the data-loading cell.
        vocab_size_src=len(comment_tokens),
        vocab_size_tgt=len(response_comment_tokens),
        dropout=dropout_rate,
    )
else:
    print("Best model checkpoint found. Loading...")
    transformer_model = tf.keras.models.load_model(
        model_full_path,
        custom_objects=dict(
            masked_loss=masked_loss,
            masked_accuracy=masked_accuracy,
            CustomSchedule=CustomSchedule,
            PositionalEmbedding=PositionalEmbedding,
        ),
    )

# Compile model: Adam driven by the custom learning-rate schedule, with
# gradient-norm clipping at 1.0.
lr_schedule = CustomSchedule(key_dim, warmup_steps)
optimizer = tf.keras.optimizers.Adam(
    learning_rate=lr_schedule,
    beta_1=0.9,
    beta_2=0.98,
    epsilon=1e-9,
    clipnorm=1.0,
)
transformer_model.compile(optimizer=optimizer, loss=masked_loss, metrics=[masked_accuracy])

# Train model and keep the returned history for plotting.
callbacks = get_callbacks(model_path, model_name)
history = transformer_model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epoch,
    callbacks=callbacks,
)

In [None]:
# Pass the `history` object returned by transformer_model.fit to plot_history
# (imported from generative_text.general_chat_custom.evaluate).
plot_history(history)

In [None]:
import tensorflow as tf
import os
import re
import random
from generative_text.general_chat_custom.tnn import transformer, masked_loss, masked_accuracy

from generative_text.general_chat_custom.tnn import CustomSchedule
from generative_text.general_chat_custom.PositionalEmbedding import PositionalEmbedding

# Custom classes/functions the saved model refers to by name.
custom_objects = {
    "PositionalEmbedding": PositionalEmbedding,
    "CustomSchedule": CustomSchedule,
    "masked_loss": masked_loss,
    "masked_accuracy": masked_accuracy
}

best_model_path = "./models/best_model.h5"

# Load the trained model with custom objects.
# Fail fast when the checkpoint is missing: merely printing a message would leave
# `model` undefined, or still bound to a model loaded by an earlier cell, and the
# translate() calls below would then fail or silently use the wrong network.
if not os.path.exists(best_model_path):
    raise FileNotFoundError(
        f"Best model not found at {best_model_path}. Please check the path or train the model."
    )

print(f"Loading best model from {best_model_path}")
with tf.keras.utils.custom_object_scope(custom_objects):
    model = tf.keras.models.load_model(best_model_path)

# Define the translate function
def translate(sentence, max_len=max_len):
    """Greedily decode a response for ``sentence`` using the module-level ``model``.

    The input is tokenized with ``comment_vectorizer``; the partially generated
    output is tokenized with ``response_comment_vectorizer``, whose vocabulary is
    also used to map predicted indices back to words. At each step the
    highest-scoring token at position ``i`` is appended until ``[end]`` is
    predicted, ``max_len`` steps have run, or ``i`` exceeds the prediction length.

    Args:
        sentence: Raw input text.
        max_len: Maximum number of decoding steps (defaults to the configured
            ``max_len`` at definition time).

    Returns:
        str: The generated words joined by spaces, without the ``[start]`` and
        ``[end]`` sentinels.
    """
    start_sentinel, end_sentinel = "[start]", "[end]"

    # Tokenize the input sentence
    enc_tokens = tf.reshape(comment_vectorizer([sentence]), (1, -1))
    lookup = list(response_comment_vectorizer.get_vocabulary())
    output_sentence = [start_sentinel]

    # Generate the output sentence
    for i in range(max_len):
        # The decoder input is target-side text, so it must be encoded with the
        # response vectorizer -- the same vocabulary `lookup` decodes with.
        vector = response_comment_vectorizer([" ".join(output_sentence)])
        # Drop the last position before feeding the decoder.
        # NOTE(review): assumes the vectorizer pads to one more position than the
        # decoder input length used in training -- confirm.
        dec_tokens = tf.reshape(vector[:, :-1], (1, -1))
        pred = model([enc_tokens, dec_tokens])
        if i >= pred.shape[1]:
            print(f"Index {i} is out of bounds for prediction with shape {pred.shape}.")
            break
        word_index = tf.argmax(pred[0, i, :], axis=-1).numpy()
        word = lookup[word_index]
        if word == end_sentinel:
            # Stop without storing the sentinel, so nothing needs stripping later.
            break
        output_sentence.append(word)

    # Only the start sentinel is removed; when decoding stops without "[end]"
    # the last generated word is kept instead of being sliced off.
    return " ".join(output_sentence[1:])

# Test the translate function on a few randomly chosen (thread, comment) pairs,
# printing the model output next to the expected comment.
test_count = 5
for n in range(test_count):
    thread, comment = random.choice(text_pairs)
    translated = translate(thread)
    print(
        f"Test {n+1}:",
        f"Thread: {thread}",
        f"Expected Comment: {comment}",
        f"Translated Comment: {translated}",
        "",
        sep="\n",
    )