# In [None]:
import os, shutil
import io
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
from tensorflow import keras
from keras import layers
from tqdm import tqdm
import numpy as np
from PIL import ImageTk, Image
import matplotlib.pyplot as plt
import PySimpleGUI as sg
import tkinter as tk

# Length of the random noise vector fed to the generator.
latent_dim = 128

# Separate Adam instances for the two networks, so each keeps its own state.
opt_gen = keras.optimizers.Adam(learning_rate=1e-4)
opt_disc = keras.optimizers.Adam(learning_rate=1e-4)

# Binary cross-entropy is used for both the discriminator's real/fake targets
# and the generator's "fool the discriminator" target.
loss_fn = keras.losses.BinaryCrossentropy()


def _image_png_bytes(path, max_size=(400, 400)):
    """Return the image at *path* encoded as PNG bytes, shrunk to fit *max_size*.

    ``sg.Image.update(data=...)`` accepts raw PNG bytes. The ``with`` ensures
    the file is not left open when handleImageGen later deletes or overwrites it.
    """
    with Image.open(path) as image:
        image.thumbnail(max_size)
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
    return buffer.getvalue()


def gui():
    """Run the Art Generation window until the user closes it.

    Events handled:
      * "Mountains" / "Beach": select the saved generator/discriminator pair by
        setting the module globals ``modelG`` and ``discriminatorG``.
      * "Load Image": store the typed/browsed path in the global
        ``fileLocation`` and, if the file exists, preview it in "-IMAGE-".
      * "Run Program": call handleImageGen() and show its output in "-IMAGE1-".
        Missing selections and generation failures are reported in a popup
        instead of crashing the window.
    """
    global modelG
    global discriminatorG
    global fileLocation
    file_types = [("JPEG (*.jpg)", "*.jpg"),
                  ("All files (*.*)", "*.*")]

    layout = [
        [sg.Text("Press Beach or Mountain"),
         sg.Button("Mountains"),
         sg.Button("Beach"),
         ],
        [sg.Image(key="-IMAGE-"),
         sg.VSeperator(),
         sg.Image(key="-IMAGE1-"),
         ],
        [
            sg.Text("Image File"),
            sg.Input(size=(25, 1), key="-FILE-"),
            sg.FileBrowse(file_types=file_types),
            sg.Button("Load Image"),
        ],
        [sg.Button("Run Program")],
    ]
    window = sg.Window('Art Generation', layout, size=(500, 500))

    try:
        while True:
            event, values = window.read()
            if event == "Exit" or event == sg.WIN_CLOSED:
                break
            if event == "Run Program":
                # handleImageGen reads these globals. Previously, clicking Run
                # before choosing a scene or loading an image raised NameError.
                if globals().get("modelG") is None:
                    sg.popup("Choose Mountains or Beach before running.")
                    continue
                source = globals().get("fileLocation")
                if not source or not os.path.exists(source):
                    sg.popup("Load an existing image file before running.")
                    continue
                try:
                    img = handleImageGen()
                    window["-IMAGE1-"].update(data=_image_png_bytes(img))
                except Exception as err:  # GUI boundary: report, keep window alive
                    sg.popup_error(f"Image generation failed: {err}")
            if event == "Mountains":
                modelG = 'save_model/mountains_model'
                discriminatorG = 'save_discriminator/mountains_discriminator'
            if event == "Beach":
                modelG = 'save_model/beach_model'
                discriminatorG = 'save_discriminator/beach_discriminator'
            if event == "Load Image":
                filename = values["-FILE-"]
                print(filename)
                fileLocation = filename
                if os.path.exists(filename):
                    window["-IMAGE-"].update(data=_image_png_bytes(filename))
    finally:
        # Close the window even if an unexpected error escapes the loop.
        window.close()

    
    
def directory_funct():
    """Build an unlabeled, batched image dataset from ``celeb_dataset/user_input``.

    Images are resized to 128x128, grouped into shuffled batches of 32, and
    divided by 255 so pixel values fall in [0, 1].
    """
    raw_images = keras.preprocessing.image_dataset_from_directory(
        directory='celeb_dataset/user_input',
        label_mode=None,
        image_size=(128, 128),
        batch_size=32,
        shuffle=True,
    )
    # Rescale to [0, 1]. Presumably this matches the generator's sigmoid
    # output range; verify against the saved model.
    return raw_images.map(lambda batch: batch / 255.0)
    
    
    
def discriminator_funct(value):
    """Load and return the saved discriminator at path *value*.

    *value* is handed straight to ``tf.keras.models.load_model``. gui() sets it
    to paths such as 'save_discriminator/mountains_discriminator'.

    An earlier inline definition (previously left here as commented-out code)
    described this architecture:
    Input (128, 128, 3) -> 4 x [Conv2D(128, k=4, s=2, same) + LeakyReLU(0.2)]
    -> Flatten -> Dropout(0.2) -> Dense(1, sigmoid).
    NOTE(review): presumably the saved model matches it, but this is not
    verified here.
    """
    return tf.keras.models.load_model(value)
def model_funct(value):
    """Load and return the saved generator at path *value*.

    *value* is handed straight to ``tf.keras.models.load_model``. gui() sets it
    to paths such as 'save_model/mountains_model'.

    An earlier inline definition (previously left here as commented-out code)
    described this architecture:
    Input (latent_dim,) -> Dense(8*8*128) -> Reshape(8, 8, 128)
    -> Conv2DTranspose 128, 128, 256, 512 (k=4, s=2, same), each + LeakyReLU(0.2)
    -> Conv2D(3, k=5, same, sigmoid).
    NOTE(review): presumably the saved model matches it, but this is not
    verified here.
    """
    loaded_generator = tf.keras.models.load_model(value)
    return loaded_generator

def image_generation(generator, discriminator, dataset, epochs=10,
                     output_path="generated_images/user_input_images/generated_img.png"):
    """Fine-tune the GAN on *dataset* and save one generated sample.

    Each step trains the discriminator on a real batch versus a generated
    batch, then trains the generator to fool it (the non-saturating
    max log(D(G(z))) objective). After the final epoch, the first image of the
    last generated batch is written to *output_path*.

    Args:
        generator: Keras model mapping (batch, latent_dim) noise to images.
        discriminator: Keras model scoring images. Its output is compared with
            (batch, 1) targets via ``loss_fn``.
        dataset: iterable of image batches, e.g. from ``directory_funct``
            (assumed to be float tensors with a leading batch axis; TODO confirm).
        epochs: number of passes over *dataset* (default 10, as before).
        output_path: file the final sample is saved to.

    Returns:
        *output_path*.

    Raises:
        ValueError: if no batch was processed (empty dataset or ``epochs < 1``),
            so there is no sample to save.
    """
    # Fresh optimizers per call, cloned from the module-level templates so the
    # hyper-parameters stay defined in one place. Reusing the module-level
    # instances on a newly loaded model each run applies one optimizer to a
    # different variable set. NOTE(review): Keras >= 2.11 optimizers reject
    # that; confirm against the pinned TF version.
    gen_optimizer = type(opt_gen).from_config(opt_gen.get_config())
    disc_optimizer = type(opt_disc).from_config(opt_disc.get_config())

    fake = None
    for _ in range(epochs):
        for real in tqdm(dataset):
            batch_size = real.shape[0]
            random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
            fake = generator(random_latent_vectors)

            # Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
            # The target shape must be a tuple: tf.zeros(batch_size, 1) would
            # pass 1 as the dtype argument instead.
            with tf.GradientTape() as disc_tape:
                loss_disc_real = loss_fn(tf.ones((batch_size, 1)), discriminator(real))
                loss_disc_fake = loss_fn(tf.zeros((batch_size, 1)), discriminator(fake))
                loss_disc = (loss_disc_real + loss_disc_fake) / 2
            grads = disc_tape.gradient(loss_disc, discriminator.trainable_weights)
            disc_optimizer.apply_gradients(zip(grads, discriminator.trainable_weights))

            # Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
            with tf.GradientTape() as gen_tape:
                fake = generator(random_latent_vectors)
                output = discriminator(fake)
                loss_gen = loss_fn(tf.ones((batch_size, 1)), output)
            grads = gen_tape.gradient(loss_gen, generator.trainable_weights)
            gen_optimizer.apply_gradients(zip(grads, generator.trainable_weights))

    if fake is None:
        raise ValueError("no training batches were processed; nothing to save")
    # Always save after the last epoch. The old `idx % 100 == 0` check only
    # looked at the final batch index, so the file was often never written.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    keras.preprocessing.image.array_to_img(fake[0]).save(output_path)
    return output_path
    
            
def handleImageGen():
    """Generate an image for the selected scene from the loaded input image.

    Reads the module globals set by ``gui()``: ``discriminatorG`` and
    ``modelG`` (saved-model paths for the chosen scene) and ``fileLocation``
    (the input image).

    Steps:
      1. Load the discriminator and generator.
      2. Clear images from ``celeb_dataset/user_input/`` and copy the input
         there, because ``directory_funct`` builds the dataset from that folder.
      3. Delete previously generated PNGs, then fine-tune and save a new sample.

    Returns:
        Path of the generated PNG that ``image_generation`` writes.
    """
    discriminator = discriminator_funct(discriminatorG)
    generator = model_funct(modelG)

    # Clear every format the Keras directory loader reads, not just .png.
    # The GUI offers *.jpg files, which used to pile up here and be trained on
    # in every later run.
    image_exts = ('.bmp', '.gif', '.jpeg', '.jpg', '.png')
    target_dataset = 'celeb_dataset/user_input/'
    os.makedirs(target_dataset, exist_ok=True)
    for name in os.listdir(target_dataset):
        if name.lower().endswith(image_exts):
            os.unlink(os.path.join(target_dataset, name))
    # Copy, not move. Moving deleted the user's original file and made a second
    # "Run Program" with the same selection fail.
    shutil.copy(fileLocation, target_dataset)
    dataset = directory_funct()

    target_generated = 'generated_images/user_input_images/'
    os.makedirs(target_generated, exist_ok=True)
    for name in os.listdir(target_generated):
        if name.endswith('.png'):
            os.unlink(os.path.join(target_generated, name))
    image_generation(generator, discriminator, dataset)
    return 'generated_images/user_input_images/generated_img.png'
    

def main():
    """Launch the Art Generation GUI; returns once its window is closed."""
    return gui()


if __name__ == "__main__":
    # Only start the GUI when run as a script, not when imported.
    main()
