## Conditional Generative Adverserial Networks

In [10]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt

In [11]:
# 1. Load MNIST dataset
(x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = (x_train.astype("float32") - 127.5) / 127.5  # Scale to [-1, 1]
x_train = np.expand_dims(x_train, axis=-1)
num_classes = 10
latent_dim = 100


In [12]:
# 2. Build Generator
def build_generator():
    label_input = layers.Input(shape=(1,), dtype="int32")
    label_embedding = layers.Embedding(num_classes, 50)(label_input)
    label_embedding = layers.Flatten()(label_embedding)

    noise_input = layers.Input(shape=(latent_dim,))
    model_input = layers.Concatenate()([noise_input, label_embedding])

    x = layers.Dense(7 * 7 * 256, activation="relu")(model_input)
    x = layers.Reshape((7, 7, 256))(x)
    x = layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding="same", activation="relu")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding="same", activation="relu")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding="same", activation="tanh")(x)

    model = models.Model([noise_input, label_input], x)
    return model

In [13]:
# 3. Build Discriminator
def build_discriminator():
    img_input = layers.Input(shape=(28, 28, 1))

    label_input = layers.Input(shape=(1,), dtype="int32")
    label_embedding = layers.Embedding(num_classes, 50)(label_input)
    label_embedding = layers.Flatten()(label_embedding)
    label_embedding = layers.Dense(28 * 28)(label_embedding)
    label_embedding = layers.Reshape((28, 28, 1))(label_embedding)

    merged = layers.Concatenate(axis=-1)([img_input, label_embedding])

    x = layers.Conv2D(64, (5, 5), strides=(2, 2), padding="same")(merged)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Dropout(0.3)(x)

    x = layers.Conv2D(128, (5, 5), strides=(2, 2), padding="same")(x)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Dropout(0.3)(x)

    x = layers.Flatten()(x)
    x = layers.Dense(1, activation="sigmoid")(x)

    model = models.Model([img_input, label_input], x)
    return model

In [14]:
# 4. Compile Models
generator = build_generator()
discriminator = build_discriminator()
discriminator.compile(optimizer=tf.keras.optimizers.Adam(0.0002, 0.5),
                      loss="binary_crossentropy",
                      metrics=["accuracy"])

# Combined model
noise = layers.Input(shape=(latent_dim,))
label = layers.Input(shape=(1,), dtype="int32")
img = generator([noise, label])

discriminator.trainable = False
valid = discriminator([img, label])

cgan = models.Model([noise, label], valid)
cgan.compile(optimizer=tf.keras.optimizers.Adam(0.0002, 0.5), loss="binary_crossentropy")

In [15]:
# 5. Save Generated Images
def save_images(epoch):
    r, c = 2, 5
    noise = np.random.normal(0, 1, (r * c, latent_dim))
    labels = np.array([num for _ in range(r) for num in range(c)])
    gen_imgs = generator.predict([noise, labels])

    gen_imgs = 0.5 * gen_imgs + 0.5  # Rescale to [0,1]

    fig, axs = plt.subplots(r, c, figsize=(c, r))
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap="gray")
            axs[i, j].set_title(f"Digit: {labels[cnt]}")
            axs[i, j].axis("off")
            cnt += 1
    plt.suptitle(f"Generated images at epoch {epoch}")
    plt.show()

# Training Loop
def train(epochs, batch_size=128, save_interval=1000):
    half_batch = batch_size // 2

    for epoch in range(1, epochs + 1):
        # Train Discriminator
        idx = np.random.randint(0, x_train.shape[0], half_batch)
        imgs, labels = x_train[idx], y_train[idx]

        noise = np.random.normal(0, 1, (half_batch, latent_dim))
        gen_labels = np.random.randint(0, num_classes, half_batch)
        gen_imgs = generator.predict([noise, gen_labels])

        d_loss_real = discriminator.train_on_batch([imgs, labels], np.ones((half_batch, 1)))
        d_loss_fake = discriminator.train_on_batch([gen_imgs, gen_labels], np.zeros((half_batch, 1)))
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Train Generator
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        sampled_labels = np.random.randint(0, num_classes, batch_size)
        g_loss = cgan.train_on_batch([noise, sampled_labels], np.ones((batch_size, 1)))

        # Print progress
        if epoch % 100 == 0:
            print(f"{epoch} [D loss: {d_loss[0]:.4f}, acc: {100*d_loss[1]:.2f}%] [G loss: {g_loss:.4f}]")

        # Save generated images
        if epoch % save_interval == 0:
            save_images(epoch)

In [16]:
# 6. Save Generated Images
def save_images(epoch):
    r, c = 2, 5
    noise = np.random.normal(0, 1, (r * c, latent_dim))
    labels = np.array([num for _ in range(r) for num in range(c)])
    gen_imgs = generator.predict([noise, labels])

    gen_imgs = 0.5 * gen_imgs + 0.5  # Rescale to [0,1]

    fig, axs = plt.subplots(r, c, figsize=(c, r))
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap="gray")
            axs[i, j].set_title(f"Digit: {labels[cnt]}")
            axs[i, j].axis("off")
            cnt += 1
    plt.suptitle(f"Generated images at epoch {epoch}")
    plt.show()


In [17]:
# 7. Run Training
train(epochs=100, batch_size=64, save_interval=2000)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 486ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 119ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 136ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 91ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 62ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 97ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 54ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 56ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 72ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 60ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 46ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 50ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 51ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m