# Import the packages and define the (environment) variables

In [None]:
import os

import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
# Root folder of the local CelebA copy (Windows-style path with a trailing
# separator); the image folder and the attribute CSV are located relative to it.
dir_path = "F:\\celeb_a\\"

In [None]:
# Folder holding the aligned CelebA images. os.path.join is used instead of
# string concatenation so the result is a valid path whether or not dir_path
# ends with a path separator.
images_path = os.path.join(dir_path, "img_align_celeba")

In [None]:
# CSV file with the CelebA attribute labels. os.path.join is used instead of
# string concatenation so the result is a valid path whether or not dir_path
# ends with a path separator.
labels_path = os.path.join(dir_path, "celeb_a_attrs.csv")

In [None]:
# Number of images per training batch.
batch_size = 32
# Length of the flat noise vector fed to the generator. The generator reshapes
# it to a 26 x 21 x 100 tensor, hence 26 * 21 * 100 = 54600 values.
noise_size = 26 * 21 * 100

# Create the dataset pipeline

## Read the labels

*Only needed for Conditional GANs*

In [None]:
#labels = pd.read_csv(labels_path)

## Create a list of all the files

In [None]:
# Collect the full path of every file found anywhere below images_path.
image_file_list = [
    os.path.join(folder, file_name)
    for folder, _subfolders, file_names in os.walk(images_path)
    for file_name in file_names
]

## Make a tf.data dataset based on the file list

In [None]:
def _load_image(path):
    """Read one JPEG file and return it as a float32 tensor scaled to [0.0, 1.0]."""
    raw = tf.io.read_file(filename=path)
    # channels=3 forces RGB output so every element has the channel count
    # the discriminator input (218, 178, 3) expects.
    pixels = tf.io.decode_jpeg(contents=raw, channels=3)
    return tf.divide(x=tf.cast(x=pixels, dtype=tf.float32), y=255.0)


# Fail early with a readable message instead of an obscure dtype error from
# tf.io.read_file when the file list is empty.
if not image_file_list:
    raise FileNotFoundError("No image files were found; check the image folder path.")

image_dataset = tf.data.Dataset.from_tensor_slices(image_file_list)
# Shuffle the file paths, not the decoded images: a buffer of path strings is
# tiny, so the whole list can be shuffled uniformly, whereas a 10000-element
# buffer of decoded float32 images would occupy several GB of RAM.
image_dataset = image_dataset.shuffle(buffer_size=len(image_file_list), reshuffle_each_iteration=False)
# Read, decode, cast and scale in a single parallel map.
image_dataset = image_dataset.map(_load_image, num_parallel_calls=tf.data.AUTOTUNE)
image_dataset = image_dataset.batch(batch_size=batch_size, drop_remainder=True, num_parallel_calls=6)
# Let tf.data size the prefetch buffer; a fixed buffer of 1000 batches would
# keep 1000 * batch_size decoded images in memory.
image_dataset = image_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

# Build the model

## Build the generator model

In [None]:
# Generator: maps a flat noise vector to a 218 x 178 x 3 image with values in
# [0, 1] (sigmoid output). The input length must be 26 * 21 * 100 = 54600 so
# that the Reshape below succeeds.
gen_in = keras.layers.Input(shape=[54600])
g = keras.layers.Reshape(target_shape=(26, 21, 100))(gen_in)
# Three stride-2 transposed convolutions upsample the spatial size:
# 26x21 -> 54x44 -> 109x89 -> 219x179.
g = keras.layers.Conv2DTranspose(filters=30, kernel_size=4, strides=2, padding="valid", kernel_initializer="lecun_normal")(g)
g = keras.layers.BatchNormalization()(g)
g = keras.layers.Activation(activation="elu")(g)
g = keras.layers.Conv2DTranspose(filters=60, kernel_size=3, strides=2, padding="valid", kernel_initializer="lecun_normal")(g)
g = keras.layers.BatchNormalization()(g)
g = keras.layers.Activation(activation="elu")(g)
g = keras.layers.Conv2DTranspose(filters=90, kernel_size=3, strides=2, padding="valid", kernel_initializer="lecun_normal")(g)
g = keras.layers.BatchNormalization()(g)
g = keras.layers.Activation(activation="elu")(g)
# SeparableConv2D initialises its kernels from `depthwise_initializer` and
# `pointwise_initializer`; both are set explicitly so that lecun_normal is
# really applied (a `kernel_initializer` argument does not reach them).
# The "valid" 2x2 convolution trims 219x179 to the target size 218x178.
g = keras.layers.SeparableConv2D(filters=150, kernel_size=2, strides=1, padding="valid",
                                 depthwise_initializer="lecun_normal", pointwise_initializer="lecun_normal")(g)
g = keras.layers.BatchNormalization()(g)
g = keras.layers.Activation(activation="elu")(g)
g = keras.layers.SeparableConv2D(filters=200, kernel_size=2, strides=1, padding="same",
                                 depthwise_initializer="lecun_normal", pointwise_initializer="lecun_normal")(g)
g = keras.layers.BatchNormalization()(g)
g = keras.layers.Activation(activation="elu")(g)
# Dense layers act on the last (channel) axis only, i.e. per pixel.
g = keras.layers.Dense(units=80, kernel_initializer="lecun_normal")(g)
g = keras.layers.BatchNormalization()(g)
g = keras.layers.Activation(activation="elu")(g)
g = keras.layers.Dense(units=40, kernel_initializer="lecun_normal")(g)
g = keras.layers.BatchNormalization()(g)
g = keras.layers.Activation(activation="elu")(g)
# Three output channels squashed to [0, 1], matching the scaling of the real images.
gen_out = keras.layers.Dense(units=3, activation="sigmoid")(g)

In [None]:
# Wrap the generator layer graph (gen_in -> gen_out) into a Keras model named "gen".
m_generator = keras.Model(inputs=[gen_in], outputs=[gen_out], name="gen")

In [None]:
# Print the layer overview; useful to check that the final output shape
# matches the discriminator input shape.
m_generator.summary()

## Build the discriminator model

In [None]:
# Discriminator: maps a 218 x 178 x 3 image to one sigmoid probability.
disc_in = keras.layers.Input(shape=(218, 178, 3))
d = keras.layers.Conv2D(filters=30, kernel_size=3, kernel_initializer="lecun_normal", padding="same")(disc_in)
d = keras.layers.Activation(activation="relu")(d)
d = keras.layers.MaxPool2D(pool_size=2, strides=2, padding="same")(d)
# SeparableConv2D initialises its kernels from `depthwise_initializer` and
# `pointwise_initializer`; both are set explicitly so that lecun_normal is
# really applied (a `kernel_initializer` argument does not reach them).
d = keras.layers.SeparableConv2D(filters=60, kernel_size=2,
                                 depthwise_initializer="lecun_normal", pointwise_initializer="lecun_normal")(d)
d = keras.layers.Activation(activation="relu")(d)
d = keras.layers.MaxPool2D(pool_size=2, strides=2, padding="same")(d)
d = keras.layers.SeparableConv2D(filters=90, kernel_size=2,
                                 depthwise_initializer="lecun_normal", pointwise_initializer="lecun_normal")(d)
d = keras.layers.Activation(activation="relu")(d)
d = keras.layers.MaxPool2D(pool_size=2, strides=2, padding="same")(d)
d = keras.layers.SeparableConv2D(filters=150, kernel_size=2,
                                 depthwise_initializer="lecun_normal", pointwise_initializer="lecun_normal")(d)
d = keras.layers.Activation(activation="relu")(d)
# Collapse the spatial axes to one 150-value feature vector per image.
d = keras.layers.GlobalMaxPool2D()(d)
disc_out = keras.layers.Dense(units=1, activation="sigmoid")(d)

In [None]:
# Wrap the discriminator layer graph (disc_in -> disc_out) into a Keras model named "disc".
m_discriminator = keras.Model(inputs=[disc_in], outputs=[disc_out], name="disc")

## Customizing the `fit` function

In a GAN, each full training step consists of two individual training steps:

- Training the discriminator
- Training the generator

The steps in training the discriminator are the following:
- We create random, normally distributed noise
- We pass that noise to the generator, which in turn creates the generated samples
- We then create labels for the real and generated samples
- We concatenate the generated samples with the real ones
- We then train the discriminator in a supervised fashion, i.e.
    - passing the concatenated samples to the discriminator to generate its predictions,
    - calculating the loss based on its predictions and the true labels,
    - calculating the gradients based on this loss and applying them

After that we train the generator in the following way (second part of the full training step):
- We create random, normally distributed noise
- We pass that noise to the generator, which in turn creates the generated samples
- We keep the discriminator weights fixed and train the generator in a supervised fashion, i.e.
    - passing the generated samples to the discriminator to generate the fake/real predictions,
    - calculating the loss based on its predictions and the labels that mark the samples as real,
    - calculating the gradients based on this loss and applying them to the generator only

In [None]:
class MyGAN(keras.Model):
    """GAN wrapper that trains a generator and a discriminator inside ``fit``.

    Label convention used throughout: generated samples are labelled 1 and
    real samples are labelled 0, so the generator is trained to make the
    discriminator output 0 for its samples.

    Args:
        generator: Model mapping noise of shape ``(batch, noise_size)`` to samples.
        discriminator: Model mapping samples to one probability per sample.
        noise_gen: Random generator exposing ``normal(shape=...)``
            (a ``tf.random.Generator`` in this notebook).
        noise_size: Length of the noise vector fed to the generator.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self, generator, discriminator, noise_gen, noise_size=54600, **kwargs):
        super().__init__(**kwargs)
        self.generator = generator
        self.discriminator = discriminator
        self.noise_gen = noise_gen
        self.noise_size = noise_size

    def compile(self, disc_optimizer, gen_optimizer, loss_fn):
        """Store one optimizer per sub-model and the shared loss function.

        Args:
            disc_optimizer: Optimizer applied to the discriminator variables.
            gen_optimizer: Optimizer applied to the generator variables.
            loss_fn: Callable ``loss_fn(y_true, y_pred)`` used for both updates.
        """
        super().compile()
        self.disc_optimizer = disc_optimizer
        self.gen_optimizer = gen_optimizer
        self.loss_fn = loss_fn

    def _sample_noise(self, batch_size):
        """Draw a ``(batch_size, noise_size)`` batch of normally distributed noise."""
        return self.noise_gen.normal(shape=tf.stack([batch_size, self.noise_size]))

    def train_step(self, data):
        """Run one discriminator update followed by one generator update.

        Args:
            data: Batch of real samples with the batch dimension first.

        Returns:
            Dict with the scalar losses ``"disc_loss"`` and ``"gen_loss"``.
        """
        # Use the dynamic batch size so the step also works when the static
        # batch dimension is unknown (e.g. a dataset batched without
        # drop_remainder=True).
        batch_size = tf.shape(data)[0]

        # --- Train the discriminator model ---
        # training=True is passed explicitly: inside a custom train_step the
        # sub-models are otherwise called in inference mode, which would make
        # the generator's BatchNormalization layers use their (never updated)
        # moving statistics instead of batch statistics.
        X_gen = self.generator(self._sample_noise(batch_size), training=True)
        # Generated samples come first, so their labels (1) come first as well.
        X_all = tf.concat([X_gen, data], axis=0)
        y_true = tf.concat(
            [tf.ones(shape=[batch_size, 1]), tf.zeros(shape=[batch_size, 1])],
            axis=0,
        )
        # Small random label noise to keep the discriminator from becoming
        # over-confident.
        y_true += 0.05 * tf.random.uniform(shape=tf.shape(y_true))
        with tf.GradientTape() as disc_tape:
            y_pred = self.discriminator(X_all, training=True)
            disc_loss = self.loss_fn(y_true, y_pred)
        disc_vars = self.discriminator.trainable_variables
        disc_gradients = disc_tape.gradient(disc_loss, disc_vars)
        self.disc_optimizer.apply_gradients(zip(disc_gradients, disc_vars))

        # --- Train the generator model ---
        noise = self._sample_noise(batch_size)
        # Target label 0 means "real" under the convention above.
        y_true = tf.zeros(shape=[batch_size, 1])
        with tf.GradientTape() as gen_tape:
            X_gen = self.generator(noise, training=True)
            y_pred = self.discriminator(X_gen, training=True)
            gen_loss = self.loss_fn(y_true, y_pred)
        # Only the generator variables are updated; the discriminator stays
        # fixed during this half of the step.
        gen_vars = self.generator.trainable_variables
        gen_gradients = gen_tape.gradient(gen_loss, gen_vars)
        self.gen_optimizer.apply_gradients(zip(gen_gradients, gen_vars))

        return {"disc_loss": disc_loss, "gen_loss": gen_loss}

## Define training parameters and compile model

In [None]:
# One SGD instance per network, both with identical settings
# (Nesterov momentum 0.9, learning rate 0.001).
sgd_settings = {"learning_rate": 0.001, "momentum": 0.9, "nesterov": True}
disc_opti = keras.optimizers.SGD(**sgd_settings)
gen_opti = keras.optimizers.SGD(**sgd_settings)
# Binary cross-entropy between the real/generated labels and the discriminator output.
loss_fn = keras.losses.BinaryCrossentropy()
# Seeded generator used to draw the noise vectors.
noise_gen = tf.random.Generator.from_seed(seed=1224)

In [None]:
# Bundle both networks and the noise source into the trainable GAN wrapper.
gan = MyGAN(
    generator=m_generator,
    discriminator=m_discriminator,
    noise_gen=noise_gen,
    noise_size=noise_size,
)

In [None]:
# Register the two optimizers and the loss with the custom compile() of MyGAN.
gan.compile(disc_optimizer=disc_opti, gen_optimizer=gen_opti, loss_fn=loss_fn)

In [None]:
# Train for 20 passes over the image dataset; each batch runs MyGAN.train_step.
# The value returned by fit is kept in `hist`.
hist = gan.fit(x=image_dataset, epochs=20, verbose=1)