
# Conditional Generative Adversarial Network (cGAN): A Comprehensive Overview

This notebook gives an overview of the Conditional Generative Adversarial Network (cGAN): its history, mathematical foundation, a TensorFlow/Keras implementation on MNIST with visualized samples, and its advantages and disadvantages.



## History of Conditional Generative Adversarial Networks (cGANs)

Conditional Generative Adversarial Networks (cGANs) were introduced by Mehdi Mirza and Simon Osindero in their 2014 paper "Conditional Generative Adversarial Nets." cGANs extend the original GAN model by conditioning both the Generator and the Discriminator on additional information, such as class labels or other auxiliary data that should influence the generation process. This conditioning gives more control over the generated data, for example generating images of a specific class or with specific attributes.



## Mathematical Foundation of Conditional GANs

### Architecture

A Conditional GAN extends the traditional GAN architecture by conditioning both the Generator and Discriminator on some auxiliary information \( y \):

1. **Generator**: The Generator \( G \) takes a noise vector \( z \) and the conditional information \( y \) as input and produces synthetic data \( G(z|y) \).

\[
G(z|y) = \text{synthetic data conditioned on } y
\]

2. **Discriminator**: The Discriminator \( D(x|y) \) takes both the data \( x \) (real or synthetic) and the conditional information \( y \) as input and outputs the probability that the data is real.

\[
D(x|y) = P(\text{real} | x, y)
\]
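
One common way to realize this conditioning, and the one used for the Generator in the implementation later in this notebook, is to concatenate the noise vector with a one-hot encoding of the label. The NumPy sketch below only illustrates the shapes involved; `one_hot` and `generator_input` are hypothetical helper names chosen for illustration, not part of any library.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Return a (len(labels), num_classes) one-hot matrix for integer labels."""
    labels = np.asarray(labels, dtype=int)
    encoded = np.zeros((labels.shape[0], num_classes), dtype=np.float32)
    encoded[np.arange(labels.shape[0]), labels] = 1.0
    return encoded

def generator_input(z, y_onehot):
    """Concatenate noise z and condition y along the feature axis."""
    return np.concatenate([z, y_onehot], axis=1)

rng = np.random.default_rng(0)
z = rng.normal(size=(4, 100)).astype(np.float32)   # 4 noise vectors of dimension 100
y = one_hot([3, 3, 7, 0], num_classes=10)          # requested digit classes
g_in = generator_input(z, y)
print(g_in.shape)  # (4, 110)
```

Two rows share the label 3 but carry different noise, which is what lets a conditioned Generator produce varied samples of the same class.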

### Loss Function

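The objective of cGANs is to solve the following minimax problem, in which the condition \( y \) accompanies each sample and the expectations are implicitly taken over it as well: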

\[
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_\text{data}(x)}[\log D(x|y)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z|y)|y))]
\]

- The Discriminator tries to maximize the probability of correctly distinguishing real data from fake data, given the condition \( y \).
- The Generator tries to minimize the probability that the Discriminator correctly classifies the fake data as fake, conditioned on \( y \).
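
To make the two roles concrete, the following NumPy sketch evaluates a batch estimate of \( V(D, G) \) from made-up Discriminator outputs. The probabilities are illustrative values, not results from a trained model, and the function names are hypothetical. It also shows the non-saturating Generator loss \( -\log D(G(z|y)|y) \), which is what training the Generator against "real" targets with binary cross-entropy (as the implementation below does) actually minimizes.

```python
import numpy as np

def value_function(d_real, d_fake, eps=1e-7):
    """Batch estimate of V(D, G) = E[log D(x|y)] + E[log(1 - D(G(z|y)|y))]."""
    d_real = np.clip(d_real, eps, 1 - eps)
    d_fake = np.clip(d_fake, eps, 1 - eps)
    return np.mean(np.log(d_real)) + np.mean(np.log(1.0 - d_fake))

def generator_loss_nonsaturating(d_fake, eps=1e-7):
    """Binary cross-entropy of D's outputs on fakes against the target 'real' (1)."""
    d_fake = np.clip(d_fake, eps, 1 - eps)
    return -np.mean(np.log(d_fake))

# Made-up Discriminator outputs: D(x|y) on real pairs, D(G(z|y)|y) on generated pairs
d_real = np.array([0.9, 0.8, 0.95])
d_fake = np.array([0.1, 0.3, 0.2])

v = value_function(d_real, d_fake)
d_loss = -v                      # the Discriminator maximizes V, i.e. minimizes -V
g_loss = generator_loss_nonsaturating(d_fake)
print(f"V = {v:.4f}, D loss = {d_loss:.4f}, G loss = {g_loss:.4f}")
```

If the Discriminator outputs 0.5 for every input, the estimate equals \( -\log 4 \), the value reached at equilibrium in the original GAN analysis.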

### Training

Training a cGAN alternates between two gradient-based updates: the Discriminator is updated to increase \( V(D, G) \), then the Generator is updated to decrease it, with both networks receiving the auxiliary information \( y \). The conditional information gives the model additional context, which lets it generate data with specific attributes.
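
Written as update rules, one round of this alternation can be sketched as follows. This assumes plain stochastic gradient steps with a learning rate \( \eta \) and network parameters \( \theta_D \) and \( \theta_G \); these symbols are introduced here only for illustration, and practical implementations usually substitute an adaptive optimizer such as Adam.

\[
\theta_D \leftarrow \theta_D + \eta \, \nabla_{\theta_D} V(D, G), \qquad
\theta_G \leftarrow \theta_G - \eta \, \nabla_{\theta_G} V(D, G)
\]

The opposite signs reflect the minimax structure: the Discriminator ascends \( V \) while the Generator descends it.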



## Implementation in Python

We'll implement a cGAN with TensorFlow and Keras on the MNIST dataset. Both networks receive the digit label as a one-hot vector, so the Generator can be asked to produce a specific digit.


In [None]:

import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist

# Load and preprocess the MNIST dataset
(x_train, y_train), (_, _) = mnist.load_data()
x_train = (x_train.astype('float32') - 127.5) / 127.5
x_train = np.expand_dims(x_train, axis=-1)
y_train = tf.keras.utils.to_categorical(y_train, 10)

# Generator model
def build_generator():
    noise = layers.Input(shape=(100,))
    label = layers.Input(shape=(10,))
    model_input = layers.Concatenate()([noise, label])

    x = layers.Dense(256)(model_input)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.Dense(512)(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.Dense(1024)(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.Dense(28 * 28 * 1, activation='tanh')(x)
    img = layers.Reshape((28, 28, 1))(x)

    model = models.Model([noise, label], img)
    return model

# Discriminator model
def build_discriminator():
    img = layers.Input(shape=(28, 28, 1))
    label = layers.Input(shape=(10,))
    label_layer = layers.Dense(28 * 28 * 1)(label)
    label_layer = layers.Reshape((28, 28, 1))(label_layer)

    model_input = layers.Concatenate()([img, label_layer])

    x = layers.Flatten()(model_input)
    x = layers.Dense(512)(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.Dense(256)(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    validity = layers.Dense(1, activation='sigmoid')(x)

    model = models.Model([img, label], validity)
    return model

# Build and compile the discriminator
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Build the generator
generator = build_generator()

# Create the GAN model
noise = layers.Input(shape=(100,))
label = layers.Input(shape=(10,))
img = generator([noise, label])
discriminator.trainable = False
valid = discriminator([img, label])

cgan = models.Model([noise, label], valid)
cgan.compile(loss='binary_crossentropy', optimizer='adam')

# Training the cGAN
def train_cgan(epochs, batch_size=128, save_interval=200):
    # Note: each "epoch" here is one batch update, not a full pass over the dataset
    half_batch = batch_size // 2

    for epoch in range(epochs):

        # Train Discriminator
        idx = np.random.randint(0, x_train.shape[0], half_batch)
        imgs, labels = x_train[idx], y_train[idx]

        noise = np.random.normal(0, 1, (half_batch, 100))
        gen_labels = np.random.randint(0, 10, half_batch)
        gen_labels = tf.keras.utils.to_categorical(gen_labels, 10)
        # verbose=0 avoids printing a progress bar on every training step
        gen_imgs = generator.predict([noise, gen_labels], verbose=0)

        d_loss_real = discriminator.train_on_batch([imgs, labels], np.ones((half_batch, 1)))
        d_loss_fake = discriminator.train_on_batch([gen_imgs, gen_labels], np.zeros((half_batch, 1)))
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Train Generator
        noise = np.random.normal(0, 1, (batch_size, 100))
        valid_y = np.ones((batch_size, 1))
        sampled_labels = np.random.randint(0, 10, batch_size)
        sampled_labels = tf.keras.utils.to_categorical(sampled_labels, 10)

        g_loss = cgan.train_on_batch([noise, sampled_labels], valid_y)

        if epoch % save_interval == 0:
            print(f"{epoch} [D loss: {d_loss[0]} | D accuracy: {100*d_loss[1]}] [G loss: {g_loss}]")
            save_imgs(epoch)

def save_imgs(epoch):
    # Despite its name, this displays a 2x5 grid of digits 0-9 rather than writing a file
    r, c = 2, 5
    noise = np.random.normal(0, 1, (r * c, 100))
    sampled_labels = np.arange(0, 10).reshape(-1, 1)
    sampled_labels = tf.keras.utils.to_categorical(sampled_labels, 10)
    gen_imgs = generator.predict([noise, sampled_labels], verbose=0)

    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            axs[i, j].set_title(f"Digit: {cnt}")
            axs[i, j].axis('off')
            cnt += 1
    plt.show()

train_cgan(epochs=10000, batch_size=64, save_interval=1000)
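
Once training has run, the Generator can be asked for specific digits by choosing the labels directly. The helper below is a minimal sketch of that usage, assuming a trained model with the same two-input interface as `generator` above (noise of dimension 100 and a one-hot label of length 10). `generate_digits` is a hypothetical name, not part of Keras.

```python
import numpy as np

def generate_digits(generator, digits, latent_dim=100, seed=None):
    """Generate one image per requested digit and rescale it to [0, 1].

    `generator` is assumed to expose predict([noise, one_hot_labels]) and to
    return tanh-scaled images in [-1, 1], like the model built above.
    """
    digits = np.asarray(digits, dtype=int)
    rng = np.random.default_rng(seed)
    noise = rng.normal(0.0, 1.0, size=(digits.shape[0], latent_dim)).astype("float32")
    labels = np.eye(10, dtype="float32")[digits]       # one-hot encode the digits
    imgs = generator.predict([noise, labels], verbose=0)
    return 0.5 * np.asarray(imgs) + 0.5                # map [-1, 1] to [0, 1]
```

For example, `generate_digits(generator, [4, 2])` would return an array of shape `(2, 28, 28, 1)` whose images can be passed to `plt.imshow` one at a time.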



## Pros and Cons of Conditional GANs

### Advantages
- **Controlled Data Generation**: cGANs allow for controlled generation of data by conditioning the output on auxiliary information, making them suitable for tasks like image-to-image translation and text-to-image synthesis.
- **Versatility**: The conditional framework can be applied to a wide range of applications, providing more flexibility than traditional GANs.

### Disadvantages
- **Complexity**: cGANs add complexity to the training process due to the additional conditioning information, which may require more careful tuning.
- **Training Instability**: Like traditional GANs, cGANs can suffer from training instability and mode collapse, which may be exacerbated by the conditional information.



## Conclusion

Conditional Generative Adversarial Networks (cGANs) provide a powerful extension to the GAN framework by allowing for controlled data generation based on auxiliary information. This makes them highly valuable in tasks where specific attributes of the generated data need to be controlled. While cGANs inherit some of the challenges of traditional GANs, their versatility and ability to generate conditioned data make them a crucial tool in modern machine learning.
