# Wasserstein GAN with Gradient Penalty (WGAN-GP)

__Objective:__ explore WGAN-GP models.

__Source:__ [notebook](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition/blob/main/notebooks/04_gan/02_wgan_gp/wgan_gp.ipynb) (in turn inspired by this [Keras example](https://keras.io/examples/generative/wgan_gp/)).

In [None]:
import os
import sys
import datetime
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns

sys.path.append('../modules/')

from utils import (preprocess_image, inverse_preprocessing,
    get_latest_model_logs_dir, select_model_logs_dir, DATETIME_FORMAT)
from wgan_gp_critic import Critic
from generator import Generator
from wgan_gp import WGANGP

sys.path.append('../../bayesian-explorations/modules/')

from keras_utilities import append_to_full_history, plot_history

# Apply seaborn's theme to the matplotlib figures drawn in this notebook.
sns.set_theme()

# IPython autoreload: with mode 2, imported modules (including the project
# modules loaded from ../modules/ above) are reloaded before each cell runs,
# so edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Directory holding the training images (read by
# `image_dataset_from_directory` in the "Load data" cell).
DATA_DIR = '../data/dataset/'
# Root directory passed to `select_model_logs_dir` to pick the TensorBoard
# logs directory.
LOGS_DIR = '../logs/'

## Load data

In [None]:
# Unlabeled, single-channel 64x64 images read from DATA_DIR, in batches of 128.
training_data = tf.keras.utils.image_dataset_from_directory(
    directory=DATA_DIR,
    labels=None,
    color_mode='grayscale',
    batch_size=128,
    image_size=(64, 64),
)

# Repeat indefinitely: the number of batches per epoch is then set by
# `steps_per_epoch` when calling `fit`.
training_data = training_data.repeat()

# Preprocess the data. `preprocess_image` is passed directly (a wrapping
# lambda adds nothing). Batches are mapped in parallel (order is preserved,
# since `map` is deterministic by default) and prefetched so that the input
# pipeline overlaps with training.
training_data = (
    training_data
    .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Display the first `n_images` (preprocessed) images of one training batch.
n_images = 3

fig, axs = plt.subplots(nrows=1, ncols=n_images, figsize=(14, 6))

image_batch = next(iter(training_data))

# One panel per image; grid lines are hidden since they clutter images.
for idx, ax in enumerate(axs):
    ax.imshow(image_batch[idx].numpy(), cmap='gray')
    ax.grid(False)

## Instantiate the critic part of the model

In [None]:
# Critic network of the WGAN-GP, constructed with its default arguments.
critic = Critic()

In [None]:
# Sanity-check the critic's forward pass on a handful of real training images:
# take a single batch from the pipeline and score its first five images.
(test_batch,) = training_data.take(1)

critic(test_batch[:5])

In [None]:
# Print the critic's layer-by-layer summary (the forward pass above has
# already been run on real data).
critic.summary()

## Instantiate the generator part of the model

The generator architecture is the same as the one used for the standard GAN.

In [None]:
# Generator network, constructed with its default arguments.
generator = Generator()

In [None]:
# Test generating images (untrained generator).
n_images = 3

# `squeeze=False` keeps `axs` 2-D even when `n_images == 1`, so indexing
# works for any number of panels.
fig, axs = plt.subplots(nrows=1, ncols=n_images, figsize=(14, 6), squeeze=False)

# Draw exactly one latent vector per panel. The latent size (100) must match
# the `latent_dim` used for the WGANGP model.
generated_images = inverse_preprocessing(
    generator(tf.random.normal(shape=(n_images, 100)))
)

for i, ax in enumerate(axs[0]):
    ax.imshow(generated_images[i, ...].numpy(), cmap='gray')
    ax.grid(False)

In [None]:
# Run the generator once on a small batch of 100-dimensional latent vectors
# before printing its summary. Presumably this makes sure the model is built
# (weights created) even if the previous cell was skipped — TODO confirm.
# The output of this call is discarded.
generator(tf.random.normal(shape=(3, 100)))

generator.summary()

## Full WGAN-GP model

Create and train the full WGAN-GP model.

In [None]:
# TensorBoard logging: `select_model_logs_dir` picks the logs directory under
# LOGS_DIR (with `append_to_latest_logs=True` it presumably reuses the most
# recent one — see `utils.select_model_logs_dir`).
model_logs_dir = select_model_logs_dir(LOGS_DIR, append_to_latest_logs=True)
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=model_logs_dir)

# Assemble the WGAN-GP from the critic and generator built above.
# `critic_steps` and `gp_weight` are forwarded unchanged to `WGANGP`
# (presumably critic updates per generator update and gradient-penalty weight).
wgangp_model = WGANGP(
    critic=critic,
    generator=generator,
    latent_dim=100,
    critic_steps=3,
    gp_weight=10.,
)

# Empty `History`; the records of every subsequent `fit` call get appended
# to it.
full_history = tf.keras.callbacks.History()

In [None]:
# Adam settings as in the Keras WGAN-GP example this notebook derives from
# (learning rate 2e-4, beta_1 = 0.5, beta_2 = 0.9), instead of the Keras Adam
# defaults (1e-3, 0.9, 0.999) that were passed here before. A high first-moment
# decay (beta_1 = 0.9) is known to destabilise WGAN-GP critic training; the
# WGAN-GP paper uses a low beta_1 (0) and beta_2 = 0.9.
# Separate optimizer instances are required: each tracks its own slot variables.
wgangp_model.compile(
    c_optimizer=tf.keras.optimizers.Adam(
        learning_rate=0.0002, beta_1=0.5, beta_2=0.9
    ),
    g_optimizer=tf.keras.optimizers.Adam(
        learning_rate=0.0002, beta_1=0.5, beta_2=0.9
    ),
)

Run a short test fit of the model (a few training steps only).

**Warning:** this may take a long time on an average machine!

In [None]:
# Short smoke-test run: a single epoch of 5 steps, drawn from the infinitely
# repeated dataset and logged to TensorBoard.
training_history = wgangp_model.fit(
    training_data,
    steps_per_epoch=5,
    epochs=1,
    callbacks=[tensorboard_callback],
)

# Fold this run's records into the notebook-wide history (see
# `keras_utilities.append_to_full_history`).
append_to_full_history(training_history, full_history)

In [None]:
# Plot the accumulated training history (presumably the per-epoch losses
# logged by `fit` — see `keras_utilities.plot_history`).
plot_history(full_history)

Save the model.

In [None]:
# Save the full model in the native Keras format.
saved_model_path = '../models/test_wgan_gp.keras'

# Create the parent directory first: it may not exist yet (e.g. in a fresh
# checkout, since git does not track empty directories), and saving into a
# missing directory can fail.
os.makedirs(os.path.dirname(saved_model_path), exist_ok=True)

wgangp_model.save(saved_model_path)

In [None]:
# saved_model_path = '../models/20231116_155027.keras'

# loaded_model = tf.keras.models.load_model(saved_model_path)

## Generate fake images

In [None]:
# Sample fake images from the (trained) generator of the full model.
n_images = 3

# `squeeze=False` keeps `axs` 2-D even when `n_images == 1`.
fig, axs = plt.subplots(nrows=1, ncols=n_images, figsize=(14, 6), squeeze=False)

# Draw exactly one latent vector per panel, so changing `n_images` can no
# longer cause an IndexError or empty panels. The latent size (100) must
# match the `latent_dim` the WGANGP model was built with.
images_plot = inverse_preprocessing(
    wgangp_model.generator(tf.random.normal(shape=(n_images, 100)))
)

for i, ax in enumerate(axs[0]):
    ax.imshow(images_plot[i, ...].numpy(), cmap='gray')
    ax.grid(False)