# Wasserstein GAN with Gradient Penalty (WGAN-GP)

__Objective:__ explore WGAN-GP models.

__Source:__ [notebook](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition/blob/main/notebooks/04_gan/02_wgan_gp/wgan_gp.ipynb) (in turn inspired by this [Keras example](https://keras.io/examples/generative/wgan_gp/)).

In [None]:
import sys
from datetime import datetime, timedelta
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns

sys.path.append('../modules/')

from wgan_gp_critic import Critic
from generator import Generator
from wgan_gp import WGANGP

sns.set_theme()

%load_ext autoreload
%autoreload 2

In [None]:
# Root directory holding the training images; it is read by
# `tf.keras.utils.image_dataset_from_directory` when loading the data.
DATA_DIR = '../data/dataset/'

## Load data

In [None]:
def preprocess_image(img):
    """
    Rescale pixel intensities (single channel - grayscale)
    to be in the [-1, 1] interval.

    Assumes the input pixel values lie in [0, 255]. The midpoint of that
    range is 127.5, so subtracting it and dividing by it maps 0 -> -1 and
    255 -> 1 exactly. (An offset/scale of 128 would leave the maximum at
    127/128 and never reach 1.)
    """
    return (tf.cast(img, tf.float32) - 127.5) / 127.5

In [None]:
training_data = tf.keras.utils.image_dataset_from_directory(
    directory=DATA_DIR,
    labels=None,  # Yield image batches only: the GAN needs no labels.
    color_mode='grayscale',
    batch_size=128,
    image_size=(64, 64),
)

# Preprocess the data (rescale pixel values to [-1, 1]).
# `preprocess_image` already takes a single image tensor, so it is passed
# to `map` directly rather than wrapped in a redundant lambda.
training_data = training_data.map(preprocess_image)

In [None]:
n_images = 3

fig, axs = plt.subplots(nrows=1, ncols=n_images, figsize=(14, 6))

# Take the first batch of (preprocessed) real images and show a few of them.
image_batch = next(iter(training_data))

for i, ax in enumerate(axs):
    ax.imshow(image_batch[i, ...].numpy(), cmap='gray')
    ax.grid(False)

## Instantiate the critic part of the model

In [None]:
# Build the WGAN-GP critic (class imported from the `wgan_gp_critic` module
# in ../modules/).
critic = Critic()

In [None]:
# Sanity-check the critic's forward pass on a handful of real images.
n_test_images = 5
test_batch = next(iter(training_data))
critic(test_batch[:n_test_images])

In [None]:
# Print the critic's layer-by-layer summary. The forward pass in the
# previous cell builds the model's weights, which `summary()` needs.
critic.summary()

## Instantiate the generator part of the model

The generator architecture is the same as the one used for the usual GAN.

In [None]:
# Build the generator (class imported from the `generator` module in
# ../modules/).
generator = Generator()

In [None]:
# Test generating images (untrained generator).
n_images = 3

fig, axs = plt.subplots(nrows=1, ncols=n_images, figsize=(14, 6))

# Draw one latent vector (dimension 100) per panel, so the number of
# generated images always matches `n_images` instead of a hard-coded 3.
generated_images = generator(tf.random.normal(shape=(n_images, 100)))

for i in range(n_images):
    axs[i].imshow(
        generated_images[i, ...].numpy(),
        cmap='gray'
    )

    axs[i].grid(False)

In [None]:
# One forward pass makes sure the generator's weights exist before summarizing.
_warmup_images = generator(tf.random.normal(shape=(3, 100)))
generator.summary()

## Full WGAN-GP model

In [None]:
# Combine the two networks into the full WGAN-GP model.
# latent_dim must match the size of the noise vectors fed to the generator
# above (100). critic_steps is presumably the number of critic updates per
# generator update (confirm in WGANGP). gp_weight scales the gradient
# penalty term.
wgangp_model = WGANGP(
    generator=generator,
    critic=critic,
    gp_weight=10,
    latent_dim=100,
    critic_steps=3,
)

In [None]:
# Plain SGD with the same learning rate for both networks.
critic_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
generator_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

wgangp_model.compile(
    c_optimizer=critic_optimizer,
    g_optimizer=generator_optimizer,
)

Test a single training step.

In [None]:
# Bookkeeping shared across (re-)runs of the manual training-step cell.
training_step_counter = 0
training_history, time_deltas = [], []

In [None]:
# Note: each training step is performed on one batch of training
# data, so a number (dataset_size / batch_size) of training steps
# corresponds to an epoch.
n_training_steps = 1

# Create the iterator once, outside the loop. Calling
# `next(iter(training_data))` on every step would restart the dataset each
# time, so successive steps would never walk through consecutive batches.
# `.repeat()` lets the loop run for more steps than one epoch contains
# without raising StopIteration.
data_iterator = iter(training_data.repeat())

for _ in range(n_training_steps):
    training_step_counter += 1

    t_i = datetime.now()

    batch = next(data_iterator)

    metrics_dict = wgangp_model.train_step(batch)

    t_f = datetime.now()

    # Elapsed wall-clock time of this step, in seconds (as a float).
    time_deltas.append((t_f - t_i) / timedelta(seconds=1.))

    training_history.append(metrics_dict)

    print(
        f'Training step: {training_step_counter}'
        f' | Time delta: {time_deltas[-1]}'
        f' | Discriminator loss: {metrics_dict["c_loss"]}'
        f' | Generator loss: {metrics_dict["g_loss"]}'
    )

In [None]:
# One row per recorded training step: [critic loss, generator loss].
metrics_history = tf.constant([
    [metrics['c_loss'].numpy(), metrics['g_loss'].numpy()]
    for metrics in training_history
]).numpy()

fig, axs = plt.subplots(ncols=1, nrows=2, figsize=(14, 6), sharex=True)

sns.lineplot(
    x=range(metrics_history.shape[0]),
    y=metrics_history[:, 0],
    color=sns.color_palette()[0],
    label='Discriminator loss',
    ax=axs[0]
)

plt.sca(axs[0])
plt.title('Losses', fontsize=14)
plt.ylabel('Value')
plt.legend()

sns.lineplot(
    x=range(metrics_history.shape[0]),
    y=metrics_history[:, 1],
    color=sns.color_palette()[1],
    label='Generator loss',
    ax=axs[1]
)

plt.sca(axs[1])
plt.ylabel('Value')
plt.legend()
# Each point comes from one call to `train_step` (a single batch), so the
# x axis counts training steps, not epochs.
plt.xlabel('Training step')
plt.xticks(range(metrics_history.shape[0]))

# Training time distribution.
fig = plt.figure(figsize=(14, 3))

sns.histplot(
    x=time_deltas
)

plt.title('Distribution of times for one training step', fontsize=14)
plt.xlabel('s')

Test fitting the model.

**Warning:** this may take a long time on an average machine!

In [None]:
# Train through Keras' `fit` loop, limited to a single step of a single
# epoch because training is slow.
fit_config = dict(epochs=1, steps_per_epoch=1)

wgangp_model.fit(training_data, **fit_config)

## Generate fake images

In [None]:
n_images = 3

fig, axs = plt.subplots(nrows=1, ncols=n_images, figsize=(14, 6))

# Sample fresh images from the generator after fitting. These are the
# images to plot. `generated_images` still holds the samples drawn from
# the untrained generator earlier in the notebook.
images_plot = generator(tf.random.normal(shape=(n_images, 100)))

for i in range(n_images):
    axs[i].imshow(
        images_plot[i, ...].numpy(),
        cmap='gray'
    )

    axs[i].grid(False)