In [31]:
import os
import torch
import pickle
import numpy as np
import tkinter as tk
from tkinter import ttk
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences
from transformers import TFGPT2LMHeadModel, GPT2LMHeadModel, GPT2Tokenizer


In [2]:
# Path constants for the notebook. Of these, only model_path is used by the
# cells shown here (it is handed to the GPT-2 from_pretrained loaders);
# dataset_path and yoda_file are presumably consumed by training code elsewhere.
dataset_path = "dataset/"
yoda_file = "yoda-corpus.csv"
model_path = "model/trained_model"

In [36]:
def send_message():
    """Handle a click on "Send": echo the message and append the chosen model's reply.

    Reads the module-level widgets ``input_field`` (the typed message) and
    ``radio_var`` (1 = GAN, 2 = GPT-2 loaded from "gpt2", 3 = fine-tuned GPT-2)
    and appends text to ``output_text``.

    Blank messages are ignored. If no radio button is selected, a hint is
    written to the output box instead of silently doing nothing.
    """
    message = input_field.get()
    if not message.strip():
        # Nothing to send; avoid feeding an empty prompt to the models.
        return

    # selected value -> (label printed before the reply, callable producing it).
    # The lambdas defer the lookups of the models/tokenizers until the button
    # is pressed, so this function can be defined before those objects exist.
    responders = {
        1: ("GAN: ",
            lambda: getResponsefromGan(message)),
        2: ("Untrained Model: ",
            lambda: getResponsefromPretrainedModel(message, tokenizer_untrained, model_untrained)),
        3: ("Pretrained Model (Fine Tuned): ",
            lambda: getResponsefromPretrainedModel(message, tokenizer_trained, model_trained)),
    }

    selected_option = radio_var.get()
    if selected_option not in responders:
        # radio_var holds a value outside 1-3 until the user picks a model.
        output_text.insert(tk.END, "Please select a model before sending.\n")
        output_text.see(tk.END)
        return

    label, respond = responders[selected_option]
    output_text.insert(tk.END, "user: " + message + "\n")
    response = respond()
    output_text.insert(tk.END, label + response + "\n")
    # Keep the newest line visible; the output box is only a few lines tall.
    output_text.see(tk.END)

In [37]:
def getResponsefromPretrainedModel(message, tokenizer, model, max_length=30):
    """Generate a text continuation of ``message`` with a causal language model.

    Args:
        message: Prompt text typed by the user.
        tokenizer: Tokenizer providing ``encode``, ``decode`` and ``eos_token_id``
            (a ``GPT2Tokenizer`` in this notebook).
        model: Model providing ``generate`` (a ``GPT2LMHeadModel`` in this notebook).
        max_length: Forwarded to ``model.generate`` as ``max_length``. Defaults to
            30, the value that was previously hard-coded.

    Returns:
        The decoded first generated sequence with special tokens removed, or an
        empty string when ``message`` is empty or only whitespace.
        NOTE(review): for GPT-2 the generated ids presumably begin with the
        prompt, so the returned text would start with ``message`` - confirm.
    """
    if not message or not message.strip():
        # An empty prompt would reach generate() as an empty id tensor.
        return ""

    input_ids = tokenizer.encode(message, add_special_tokens=True, return_tensors="pt")
    # Single un-padded sequence: every position is a real token.
    attention_mask = torch.ones_like(input_ids)

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        num_return_sequences=1,
        # Same padding choice the "gpt2" model is configured with at load time;
        # passing it here makes both models behave alike.
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)

In [38]:
def getResponsefromGan(message):
    """Return the GAN generator's reply to ``message`` as a single string.

    Chains the three helper stages: preprocess_input -> generate_response ->
    postprocess_response.
    """
    return postprocess_response(generate_response(preprocess_input(message)))

In [39]:
# Load every model the chat UI can talk to: the fine-tuned GPT-2, the GPT-2
# fetched by the name "gpt2", and the Keras GAN generator with its tokenizer.

# Same value as the model_path assigned in the earlier configuration cell;
# repeated here so this cell does not depend on that one having been run.
model_path = "model/trained_model"
gan_model_dir = "model/gan_model"
gan_folder_name = "model_3"
# Debug prints of the path pieces and of the joined path.
print(gan_model_dir)
print(gan_folder_name)
gan_full_path = os.path.join(gan_model_dir, gan_folder_name)
print(gan_full_path)

# One tokenizer per GPT-2 variant: the fine-tuned checkpoint directory and "gpt2".
tokenizer_trained = GPT2Tokenizer.from_pretrained(model_path)
tokenizer_untrained = GPT2Tokenizer.from_pretrained("gpt2")

# Only the "gpt2" model is given an explicit pad_token_id (its tokenizer's EOS id);
# the fine-tuned model is loaded with whatever its saved configuration contains.
model_trained = GPT2LMHeadModel.from_pretrained(model_path)
model_untrained = GPT2LMHeadModel.from_pretrained("gpt2",pad_token_id=tokenizer_untrained.eos_token_id)

# Put both torch models into evaluation mode; they are only used for generation here.
model_trained.eval()
model_untrained.eval()





# Load the GAN generator saved as an HDF5 file inside gan_full_path
print(gan_full_path)
model_gan = load_model(os.path.join(gan_full_path, "generator_model_compiled.h5"))

# Load the pickled dict that holds the GAN tokenizer and its maximum sequence length.
# NOTE(review): pickle.load can execute arbitrary code - only load tokenizer.pkl
# files produced by this project, never files from an untrusted source.
with open(os.path.join(gan_full_path, "tokenizer.pkl"), "rb") as tokenizer_file:
    tokenizer_gan_data = pickle.load(tokenizer_file)

# Unpack the two entries read by generate_response.
tokenizer_gan = tokenizer_gan_data['tokenizer']
max_sequence_length_gan = tokenizer_gan_data['max_sequence_length']

model/gan_model
model_3
model/gan_model\model_3
model/gan_model\model_3


In [40]:
def preprocess_input(user_input):
    """Split ``user_input`` on whitespace and return the lower-cased words.

    Leading and trailing whitespace is ignored; an empty or blank string
    yields an empty list.
    """
    return [word.lower() for word in user_input.split()]


def generate_response(input_tokens, generator_model=model_gan):
    """Generate reply tokens for ``input_tokens`` with a GAN generator model.

    Args:
        input_tokens: List of word tokens, e.g. the output of preprocess_input.
        generator_model: Model whose ``predict`` produces the reply. Defaults to
            the ``model_gan`` object that exists when this function is defined.

    Returns:
        List of words looked up in ``tokenizer_gan.index_word``; predicted
        indices with no entry in that mapping are dropped.
    """
    # Wrap the token list so it is encoded as one sample, then pad to the
    # length used here (max_sequence_length_gan - 1).
    input_sequence = tokenizer_gan.texts_to_sequences([input_tokens])
    input_sequence = pad_sequences(input_sequence, maxlen=max_sequence_length_gan-1)

    # Use the model that was passed in. Previously the global model_gan was
    # always called, so the generator_model argument had no effect.
    generated_sequence = generator_model.predict(input_sequence)

    # Take the highest-scoring index along the last axis for the first sample.
    # NOTE(review): assumes predict returns (batch, positions, vocabulary) - confirm.
    predicted_indices = np.argmax(generated_sequence, axis=-1)[0]

    # Map indices back to words; indices without a word map to "" and are skipped.
    generated_tokens = [tokenizer_gan.index_word.get(index, "") for index in predicted_indices]
    return [token for token in generated_tokens if token]


def postprocess_response(response_tokens):
    """Join the response tokens into one space-separated string."""
    separator = ' '
    return separator.join(response_tokens)


In [41]:
# Build the chat window: an entry for the user's message, three radio buttons to
# choose the responding model, a Send button wired to send_message, and a text box
# that collects the conversation. The names input_field, radio_var and output_text
# are module-level because send_message reads them as globals.

# Create the main window
window = tk.Tk()
window.title("Simple UI Example")

# Calculate the center position of the window
window_width = 400
window_height = 300
screen_width = window.winfo_screenwidth()
screen_height = window.winfo_screenheight()
x_coordinate = int((screen_width/2) - (window_width/2))
y_coordinate = int((screen_height/2) - (window_height/2))

# Set the window size and position ("WIDTHxHEIGHT+X+Y")
window.geometry(f"{window_width}x{window_height}+{x_coordinate}+{y_coordinate}")

# Create a division using LabelFrame
# (the "Custom.TLabelframe" style it names is configured just below)
form_frame = ttk.LabelFrame(window, text="Form", style="Custom.TLabelframe")
form_frame.pack(padx=10, pady=10)

# Custom style for LabelFrame
style = ttk.Style()
style.configure("Custom.TLabelframe", background="white")

# Create the "Input" label
input_label = ttk.Label(form_frame, text="User:")
input_label.grid(row=0, column=0, padx=10, pady=(0, 5), sticky="w")

# Create the input field
input_field = ttk.Entry(form_frame, font=('Arial', 12))
input_field.grid(row=1, column=0, columnspan=3, padx=10, pady=(0, 10), sticky="w")

# Create the radio buttons
# radio_var is created without an initial value; the buttons below set it to 1, 2
# or 3, which are the values send_message branches on.
radio_var = tk.IntVar()

# The three radio buttons sit in their own frame so they share a single row.
radio_frame = ttk.Frame(form_frame)
radio_frame.grid(row=2, column=0, padx=10, pady=10, columnspan=3)

radio_btn1 = ttk.Radiobutton(radio_frame, text="GAN", variable=radio_var, value=1)
radio_btn1.grid(row=0, column=0, padx=5)

radio_btn2 = ttk.Radiobutton(radio_frame, text="Untrained Model", variable=radio_var, value=2)
radio_btn2.grid(row=0, column=1, padx=5)

radio_btn3 = ttk.Radiobutton(radio_frame, text="Pretrained Model (Fine Tuned)", variable=radio_var, value=3)
radio_btn3.grid(row=0, column=2, padx=5)

# Create the send button; clicking it calls send_message
send_button = ttk.Button(form_frame, text="Send", command=send_message)
send_button.grid(row=3, column=0, padx=10, pady=10)

# Create the output field
output_label = ttk.Label(form_frame, text="Output:")
output_label.grid(row=4, column=0, padx=10, pady=(5, 0), sticky="w")

# Text box that send_message appends the user's messages and the model replies to
output_text = tk.Text(form_frame, height=5, font=('Arial', 12), bd=1, relief=tk.SOLID)
output_text.grid(row=5, column=0, columnspan=3, padx=10, pady=(0, 10), sticky="nsew")

# Set the input field width to match the output field
input_field.config(width=output_text["width"])

# Configure grid weights (column 0 and the output row 5 get weight 1)
form_frame.columnconfigure(0, weight=1)
form_frame.rowconfigure(5, weight=1)

# Start the main loop (this call blocks the cell until the window is closed)
window.mainloop()

