<a href="https://colab.research.google.com/github/nihermann/Pokemaenner/blob/main/Main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
##################################### Console #####################################
# Colab-only setup: uncomment these lines when running on a fresh Colab runtime.
# They clone the project repository and cd into it, so that the local modules
# imported below (`manager`, `data`, `gan`) resolve.
# Leave them commented out when running from an existing checkout.
# !git clone https://github.com/nihermann/Pokemaenner.git
# %cd Pokemaenner/
# !git status
###################################################################################

Cloning into 'Pokemaenner'...
remote: Enumerating objects: 277, done.[K
remote: Counting objects: 100% (277/277), done.[K
remote: Compressing objects: 100% (217/217), done.[K
remote: Total 277 (delta 154), reused 135 (delta 57), pack-reused 0[K
Receiving objects: 100% (277/277), 1.44 MiB | 2.08 MiB/s, done.
Resolving deltas: 100% (154/154), done.
/content/Pokemaenner


In [None]:
# Colab-only: uncomment to mount Google Drive. The Data Settings cell's default
# `image_path` ("/content/drive/MyDrive/images/") points into this mount.
# from google.colab import drive
# drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
#@title # Using Gan to create new Pokemon
import tensorflow as tf
from manager import GANManager
from data import DataGenerator
import gan

In [None]:
#@title ## Data Settings
image_shape = (64,64) #@param
image_path = "/content/drive/MyDrive/images/" #@param {type:"string"}
batch_size = 32 #@param {type:"integer"}
validation_split = 0.1 #@param {type:"slider", min:0, max:0.5, step:0.01}
shuffle = True #@param {type:"boolean"}
data_augmentation = False #@param {type:"boolean"}

# NOTE(review): `data_augmentation` is not passed to DataGenerator, so this
# checkbox currently has no effect. Confirm whether DataGenerator supports it.

# Split the (height, width) form value into the two sizes DataGenerator expects.
img_height = image_shape[0]
img_width = image_shape[1]

# Build the train/validation data source from the image folder.
data = DataGenerator(
    img_path=image_path,
    img_height=img_height,
    img_width=img_width,
    batch_size=batch_size,
    shuffle=shuffle,
    validation_split=validation_split,
)

Found 10135 files belonging to 14 classes.
Using 9122 files for training.
Found 10135 files belonging to 14 classes.
Using 1013 files for validation.


In [None]:
#@title ## Generator Arguments
# Size of the latent space handed to the generator
# (presumably the length of its random input vector; confirm in gan.Generator).
latentspace = 250 #@param {type:"slider", min:2, max:1000, step:1}


# The project module is imported as lowercase `gan` in the imports cell;
# the name `GAN` is never defined, so it would raise NameError on a fresh kernel.
generator = gan.Generator(
    latentspace=latentspace
)

In [None]:
#@title ## Discriminator Arguments
# Channels per input image. The default of 4 keeps the original hard-coded value
# (presumably RGBA).
# NOTE(review): confirm this matches the channel count DataGenerator yields.
image_channels = 4 #@param {type:"integer"}

# The project module is imported as lowercase `gan` in the imports cell;
# the name `GAN` is never defined, so it would raise NameError on a fresh kernel.
# `None` leaves the batch dimension unspecified.
discriminator = gan.Discriminator(
    input_shape=(None, image_shape[0], image_shape[1], image_channels)
)

In [None]:
#@title ## Hyperparameters


loss_function = "Binary Cross Entropy" #@param ["Binary Cross Entropy", "Mean Squared Error"]
optimizer = "Adam" #@param ["Adam", "RMSprop", "SGD"]
learning_rate = 0.001 #@param {type:"number"}


## Dropdown equivalents
# Both entries are Keras loss *objects*, so whichever is selected is called and
# reduced the same way (the bare `tf.keras.losses.MSE` function returns an
# unreduced per-sample loss, unlike `BinaryCrossentropy()`).
loss_functions = {
    "Binary Cross Entropy": tf.keras.losses.BinaryCrossentropy(),
    "Mean Squared Error": tf.keras.losses.MeanSquaredError(),
}

# Map each label to an optimizer *class*. Only the selected one is instantiated
# below, and every choice (including SGD) receives the configured learning rate.
optimizers = {
    "Adam": tf.keras.optimizers.Adam,
    "RMSprop": tf.keras.optimizers.RMSprop,
    "SGD": tf.keras.optimizers.SGD,
}


## Final
# NOTE(review): a single optimizer instance is handed to GANManager. If it is
# used for both generator and discriminator, confirm that sharing is intended.
kwargs = {
    "batch_size": batch_size,
    "loss": loss_functions[loss_function],
    "optimizer": optimizers[optimizer](learning_rate=learning_rate),
}

manager = GANManager(
    kwargs=kwargs,
    generator=generator,
    discriminator=discriminator,
    data=data
)

In [None]:
#@title # Training Parameters
epochs = 100 #@param {type:"integer"}
samples_per_epoch = 10000 #@param {type:"integer"}
# `trainings_frequency` is required by `manager.train` below but was never defined.
# NOTE(review): 1 is a placeholder default; confirm its meaning in GANManager.train.
trainings_frequency = 1 #@param {type:"integer"}
# NOTE(review): `start_from_existing_models` is not used anywhere in this cell.
start_from_existing_models = False #@param {type:"boolean"}
print_every = 1 #@param {type:"integer"}
print_verbose_option = "no_prints" #@param ["no_prints", "print_after_each_epoch", "progressbar"]
# Integer verbosity code passed to training: no_prints -> 0, progressbar -> 1,
# print_after_each_epoch -> 2.
print_verbose = {"no_prints": 0, "print_after_each_epoch": 2, "progressbar": 1}[print_verbose_option]

#@markdown ## Callbacks
# NOTE(review): `callbacks` is assembled here but never passed to `manager.train`,
# so the ModelCheckpoint / TensorBoard settings currently have no effect.
callbacks = {}
save_models = False #@param {type:"boolean"}
# Defined unconditionally: it is always read by `manager.train` below, so keeping it
# inside `if save_models:` raised NameError whenever saving was disabled.
save_model_every = 0 #@param {type:"integer"}
if save_models:
    model_path = "./models" #@param ["./models"] {allow-input: true}
    save_weights_only = False #@param {type:"boolean"}

    callbacks["ModelCheckpoint"] = {
        "filepath": model_path,
        "save_weights_only": save_weights_only,
    }

save_pictures_every = 4 #@param {type:"integer"}
how_many_pictures_to_save = 10 #@param {type:"integer"}
use_tensorboard = False #@param {type:"boolean"}
if use_tensorboard:
    log_dir = "./logs" #@param ["./logs"] {allow-input: true}
    update_frequency = "epoch" #@param ["batch", "epoch"] {allow-input: true}
    callbacks["TensorBoard"] = {
        "log_dir": log_dir,
        "update_freq" : update_frequency
    }


# When the "save_models" checkbox is off, pass 0 so the checkbox is honored.
# NOTE(review): assumes 0 disables periodic saving in GANManager.train (it is the
# form's default); confirm it does not, e.g., divide by this value.
manager.train(
    epochs=epochs,
    samples_per_epoch=samples_per_epoch,
    trainings_frequency=trainings_frequency,
    print_every=print_every,
    print_verbose=print_verbose,
    save_pictures_every=save_pictures_every,
    how_many_pictures_to_save=how_many_pictures_to_save,
    save_model_every=save_model_every if save_models else 0,
    )

|T| ---------------------------------------------------------------> |V| -------------------------------->
|D| Mean G-Loss: 12.757805824279785, Mean D-Loss: 0.005183156114071608, Mean D-Loss for real images: 0.006516627036035061, Mean D-Loss for generated images: 0.0  Mean D-Accuracy for real images: 0, Mean D-Accuracy for generated images: 0 || 
|T| ---------------------------------------------------------------> |V| -------------------------------->
|D| Mean G-Loss: 15.424947738647461, Mean D-Loss: 1.2884834177384619e-07, Mean D-Loss for real images: 0.0003814099763985723, Mean D-Loss for generated images: 0.0  Mean D-Accuracy for real images: 0, Mean D-Accuracy for generated images: 0 || 
|T| ---------------------------------------------------------------> |V| -------------------------------->
|D| Mean G-Loss: 15.424947738647461, Mean D-Loss: 1.024165641183572e-08, Mean D-Loss for real images: 0.00012278366193640977, Mean D-Loss for generated images: 0.0  Mean D-Accuracy for real im