# Conditional GANs (CGANs)

__Objective:__ explore conditional GAN models.

__Source:__ [notebook](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition/blob/main/notebooks/04_gan/03_cgan/cgan.ipynb)

**Idea:** in standard GAN models, the generator maps a vector randomly sampled from the latent space to a generated sample, while the discriminator estimates the probability that a given sample comes from the real dataset rather than from the generator. Conditional GANs feed additional inputs, usually a class label, to both the generator and the discriminator. This makes it possible, e.g., to generate samples belonging to a specified class. In particular,
- The generator takes (an encoding of) the class label as an additional input on top of the randomly sampled latent vector. It tries to convert the latent vector into a sample resembling the training samples that belong to the class specified by the label. With $z$ the latent vector and $c$ the class label, the output $x = G(z, c)$ is deterministic given both inputs, so it is the randomness of $z$ that lets the generator effectively sample from the class-conditional distribution $p(x | c)$.
- The discriminator also takes (an encoding of) a class label as an additional input, and tries to predict whether the provided sample comes from the real dataset **and belongs to the specified class**. The discriminator tries to learn the probability $p(\text{real} | x, c)$, where $x$ is the sample and $c$ is again the class label. With the WGAN-GP critic used below, the output is an unbounded score rather than a probability, but the label plays the same role.
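For reference, the standard conditional GAN objective conditions both players of the usual minimax game on $c$, with $D(x, c)$ playing the role of $p(\text{real} | x, c)$:

$$
\min_G \max_D \; \mathbb{E}_{(x, c) \sim p_{\text{data}}}\big[\log D(x, c)\big] + \mathbb{E}_{z \sim p(z),\, c \sim p(c)}\big[\log\big(1 - D(G(z, c), c)\big)\big]
$$

Since the model below is a conditional WGAN-GP, the critic $D$ instead minimizes the WGAN-GP loss with the same conditioning, where $\hat{x} = x + \alpha\,(G(z, c) - x)$ with $\alpha \sim U[0, 1]$, and $\lambda$ presumably corresponds to the `gp_weight` argument used below:

$$
\mathcal{L}_{\text{critic}} = \mathbb{E}_{z, c}\big[D(G(z, c), c)\big] - \mathbb{E}_{(x, c)}\big[D(x, c)\big] + \lambda\, \mathbb{E}_{\hat{x}, c}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}, c) \rVert_2 - 1\big)^2\Big]
$$

while the generator minimizes $\mathcal{L}_{\text{gen}} = -\mathbb{E}_{z, c}\big[D(G(z, c), c)\big]$.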

In [None]:
import os
import sys
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns

sys.path.append('../modules/')

from utils import preprocess_image

sns.set_theme()

%load_ext autoreload
%autoreload 2

## Load data

In [None]:
DATA_DIR = '../data/dataset/'

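Infer the labels from the image filenames, listed via `os.walk`. When an explicit list of labels is passed, Keras' `image_dataset_from_directory` expects it to follow the alphanumeric order of the image file paths (see the [documentation](https://keras.io/api/data_loading/image/)), so the filenames must be sorted in that order for the labels to line up with the images.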

**Labels choice:** we use the bricks dataset and split it into two classes: roof tiles (images with "roof" in their filename, label `1`) and all other tiles (label `0`). Not much imagination here, it's just an example!
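The filename-based labelling can be sketched as a small pure-Python helper that sorts the names before computing the labels, so the two lists stay aligned with the order in which the files are indexed. The helper name `labels_from_filenames` and the filenames are hypothetical, and the extension filter is an assumption meant to mirror the image formats Keras accepts (non-image files would otherwise make the label count differ from the image count).

```python
IMAGE_EXTENSIONS = ('.bmp', '.gif', '.jpeg', '.jpg', '.png')  # assumed allowed formats


def labels_from_filenames(filenames):
    """Hypothetical helper: return (sorted_names, labels).

    Label 1 if 'roof' appears in the filename, else 0. Names are sorted
    alphanumerically so that labels follow the order in which the image
    files are indexed; non-image files are dropped.
    """
    names = sorted(f for f in filenames if f.lower().endswith(IMAGE_EXTENSIONS))
    return names, [1 if 'roof' in name else 0 for name in names]


# Hypothetical filenames, in an arbitrary (e.g. filesystem) order.
names, labels = labels_from_filenames(
    ['wall_02.png', 'roof_01.png', 'notes.txt', 'floor_03.png']
)
print(names)   # ['floor_03.png', 'roof_01.png', 'wall_02.png']
print(labels)  # [0, 1, 0]
```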

In [None]:
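# Sort the filenames: Keras indexes the image files in alphanumeric order,
# so the labels list must follow the same order.
image_names = sorted(next(os.walk(DATA_DIR))[2])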

labels = [1 if ('roof' in name) else 0 for name in image_names]

print(
    'Fraction of samples with label 1:',
    tf.constant(labels).numpy().mean()
)

Load the images with the specified labels.

In [None]:
training_data = tf.keras.utils.image_dataset_from_directory(
    DATA_DIR,
    labels=labels,
    color_mode="grayscale",
    batch_size=128,
    image_size=(64, 64)
)

training_data = training_data.map(lambda img, label: (preprocess_image(img), tf.one_hot(label, depth=2)))

Plot some random images belonging to both classes.

In [None]:
images_batch, labels_batch = next(iter(training_data))

In [None]:
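n_images = 3

# Recover the integer labels from the one-hot encoded ones.
int_labels = tf.argmax(labels_batch, axis=-1)

# Row i holds the first `n_images` images with label i.
images_plot = tf.stack(
    [
        tf.boolean_mask(images_batch, int_labels == 0)[:n_images, ...],
        tf.boolean_mask(images_batch, int_labels == 1)[:n_images, ...]
    ],
    axis=0
)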

fig, axs = plt.subplots(nrows=2, ncols=n_images, figsize=(14, 6))

plt.subplots_adjust(hspace=0.3)

for i in range(2):
    for j in range(n_images):
        ax = axs[i, j]

        ax.imshow(
            images_plot[i, j, ...],
            cmap='gray'
        )

        ax.grid(False)
        ax.set_title(f'Label: {i}')

## Model

### Generator

The generator is a simple adaptation of the usual GAN/WGAN generator model, this time accepting the class label as an additional input.

In [None]:
from generator import CGANGenerator

In [None]:
generator = CGANGenerator()

For the CGAN, the generator's input is a list whose first element is the randomly sampled latent vector and whose second element is the batch of one-hot encoded class labels. The two tensors are concatenated as the first step of the generation process.
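The concatenation step can be sketched with NumPy to make the shapes explicit: a batch of latent vectors of shape `(batch, latent_dim)` and one-hot labels of shape `(batch, n_classes)` become a single `(batch, latent_dim + n_classes)` input. The helper name `generator_input` is hypothetical, and `CGANGenerator` may implement this differently; in TensorFlow the equivalent operation is `tf.concat([z, c], axis=-1)`.

```python
import numpy as np


def generator_input(latent, one_hot_labels):
    """Hypothetical helper: concatenate latent vectors and one-hot labels.

    latent: (batch, latent_dim); one_hot_labels: (batch, n_classes).
    Returns an array of shape (batch, latent_dim + n_classes).
    """
    return np.concatenate([latent, one_hot_labels], axis=-1)


rng = np.random.default_rng(0)
z = rng.standard_normal((4, 100)).astype(np.float32)
c = np.eye(2, dtype=np.float32)[[0, 1, 1, 0]]  # one-hot labels for classes 0, 1, 1, 0
x_in = generator_input(z, c)
print(x_in.shape)  # (4, 102)
```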

In [None]:
# Test the generator.
generator([tf.random.normal(shape=(1, 100)), labels_batch[:1, ...]])

In [None]:
generator.summary()

### Discriminator (critic)

The discriminator (critic) is likewise an adaptation of the GAN/WGAN one, again accepting the class label as an additional input. The images are rank-4 tensors `(batch, height, width, channels)` while the one-hot encoded labels are rank-2 tensors `(batch, classes)`, so the two cannot be concatenated directly as in the generator. Instead, the label tensor gets two new spatial axes and its values are repeated along them until it matches the images' height and width (see the implementation in the module for details). This reshaping happens **before** the inputs are passed to the critic and is implemented as a static method of the full model class (`ConditionalWGANGP`).
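The label expansion can be sketched in NumPy as follows: each one-hot vector is reshaped to `(batch, 1, 1, n_classes)` and tiled over the spatial axes, after which it can be stacked with the images along the channel axis. The helper name `expand_labels` is hypothetical and not necessarily how `ConditionalWGANGP.expand_label_tensor` is written; where the concatenation with the images happens (inside the critic or elsewhere) is also an assumption here. In TensorFlow, the tiling would read `tf.tile(labels[:, None, None, :], [1, height, width, 1])`.

```python
import numpy as np


def expand_labels(one_hot_labels, height, width):
    """Hypothetical helper: (batch, n_classes) -> (batch, height, width, n_classes).

    Each one-hot value is repeated over every spatial position.
    """
    batch, n_classes = one_hot_labels.shape
    # Insert two singleton spatial axes, then repeat along them.
    expanded = one_hot_labels.reshape(batch, 1, 1, n_classes)
    return np.tile(expanded, (1, height, width, 1))


labels = np.eye(2, dtype=np.float32)[[1, 0]]         # (2, 2): classes 1 and 0
images = np.zeros((2, 64, 64, 1), dtype=np.float32)  # grayscale batch
label_maps = expand_labels(labels, 64, 64)           # (2, 64, 64, 2)
# Stack label maps with the images along the channel axis.
critic_input = np.concatenate([images, label_maps], axis=-1)
print(critic_input.shape)  # (2, 64, 64, 3)
```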

In [None]:
from wgan_gp_critic import CGANCritic
from wgan_gp import ConditionalWGANGP

In [None]:
critic = CGANCritic()

In [None]:
ConditionalWGANGP.expand_label_tensor(labels_batch).shape

In [None]:
# Test prediction.
critic([images_batch[:1, ...], ConditionalWGANGP.expand_label_tensor(labels_batch[:1, ...])])

In [None]:
critic.summary()

### Full conditional WGAN-GP model

In [None]:
cwgan_gp_model = ConditionalWGANGP(
    critic=critic,
    generator=generator,
    latent_dim=100,
    critic_steps=3,
    gp_weight=10.
)

cwgan_gp_model.compile(
    c_optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
    g_optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
)
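The `critic_steps` and `gp_weight` arguments configure the WGAN-GP training: presumably, the critic is updated `critic_steps` times per generator update and `gp_weight` scales the gradient penalty. The penalty for a conditional critic can be sketched as below, assuming the `critic([images, expanded_labels])` calling convention from the test cell above. The function name `conditional_gradient_penalty` is hypothetical and the code in `wgan_gp.py` may differ. Only the images are interpolated between real and generated samples, while the labels are held fixed.

```python
import tensorflow as tf


def conditional_gradient_penalty(critic, real_images, fake_images, expanded_labels):
    """Hypothetical WGAN-GP penalty for a conditional critic.

    Assumes critic([images, expanded_labels], training=...) returns one
    score per sample. Returns mean((||grad_x_hat D(x_hat, c)||_2 - 1)^2).
    """
    batch_size = tf.shape(real_images)[0]
    # One interpolation coefficient per sample, broadcast over H, W, C.
    alpha = tf.random.uniform([batch_size, 1, 1, 1], minval=0.0, maxval=1.0)
    interpolated = real_images + alpha * (fake_images - real_images)

    with tf.GradientTape() as tape:
        tape.watch(interpolated)
        scores = critic([interpolated, expanded_labels], training=True)

    # Gradients w.r.t. the interpolated images only (labels are not watched).
    grads = tape.gradient(scores, interpolated)
    # Per-sample L2 norm; the small constant keeps sqrt differentiable at 0.
    norms = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
    return tf.reduce_mean(tf.square(norms - 1.0))
```

Within a critic step, the loss would then be `tf.reduce_mean(fake_scores) - tf.reduce_mean(real_scores) + gp_weight * gp`.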

Test a training step.

In [None]:
cwgan_gp_model.train_step([images_batch, labels_batch])

In [None]:
cwgan_gp_model.fit(
    training_data,
    epochs=1,
    steps_per_epoch=1
)
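After training, generating samples of a chosen class only requires fixing the label input. The sketch below builds such generator inputs with NumPy; the helper name `class_conditioned_inputs` is hypothetical, and the standard-normal prior and latent size of 100 are assumptions mirroring `tf.random.normal(shape=(1, 100))` and `latent_dim=100` above. The arrays can then be passed to the generator (converting them with `tf.convert_to_tensor` first if needed).

```python
import numpy as np


def class_conditioned_inputs(n_samples, class_idx, latent_dim=100, n_classes=2, seed=None):
    """Hypothetical helper: latent vectors plus repeated one-hot labels for one class."""
    rng = np.random.default_rng(seed)
    # Standard-normal latent vectors, one per requested sample.
    z = rng.standard_normal((n_samples, latent_dim)).astype(np.float32)
    # The same one-hot label repeated for every sample.
    c = np.zeros((n_samples, n_classes), dtype=np.float32)
    c[:, class_idx] = 1.0
    return z, c


z, c = class_conditioned_inputs(9, class_idx=1, seed=0)  # 9 roof tiles (label 1)
print(z.shape, c.shape)  # (9, 100) (9, 2)
# With the trained model, e.g.: roof_tiles = generator([z, c], training=False)
```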