Jane Howard

Gen AI

Assignment 4

Theory Questions

Q1: GANs are formulated as a two-player minimax game between a generator G and a discriminator D. In the minimax objective function, the discriminator D aims to maximize the probability of correctly classifying real samples as real and generated samples as fake. The generator G minimizes the same objective: it is trained to fool the discriminator by producing samples that the discriminator cannot distinguish from real data. The minimax structure makes training competitive, because an improvement in one network directly challenges the other. For example, as the discriminator becomes better at identifying fake samples, the generator must learn to produce more realistic data to keep fooling it.
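In the original GAN formulation the objective is min over G, max over D of V(D, G) = E_x[log D(x)] + E_z[log(1 - D(G(z)))]. The NumPy sketch below estimates V from a batch of discriminator probabilities to show which direction each player pushes it. The function name `gan_value` and the example probabilities are illustrative assumptions, not part of the assignment code.

```python
import numpy as np

def gan_value(d_real, d_fake, eps=1e-12):
    """Monte Carlo estimate of V(D, G) = E_x[log D(x)] + E_z[log(1 - D(G(z)))].

    d_real: discriminator probabilities D(x) on a batch of real samples.
    d_fake: discriminator probabilities D(G(z)) on a batch of generated samples.
    """
    # Clip to avoid log(0) when the discriminator is fully confident.
    d_real = np.clip(np.asarray(d_real, dtype=np.float64), eps, 1.0 - eps)
    d_fake = np.clip(np.asarray(d_fake, dtype=np.float64), eps, 1.0 - eps)
    return np.mean(np.log(d_real)) + np.mean(np.log(1.0 - d_fake))

# D maximizes V: a discriminator that separates real from fake scores higher ...
sharp = gan_value(d_real=[0.9, 0.95], d_fake=[0.1, 0.05])
# ... while G minimizes V: when fakes fool D (D(G(z)) near 1), V drops.
fooled = gan_value(d_real=[0.9, 0.95], d_fake=[0.9, 0.95])
print(sharp > fooled)  # True

# At the theoretical equilibrium D(x) = 0.5 everywhere, V = -log 4.
print(np.isclose(gan_value([0.5], [0.5]), -np.log(4.0)))  # True
```

The `discriminator_loss` in the training code below is the batch estimate of -V (binary cross-entropy averaged over the batch), so minimizing it is equivalent to D maximizing V.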


Q2: Mode collapse is a failure mode in GAN training in which the generator produces only a limited variety of outputs. Instead of capturing the variation in the data distribution, it often repeats the same or very similar samples. Mode collapse can happen because the generator is trained solely to fool the discriminator rather than to capture the full diversity of the real data. If a small set of outputs consistently fools the discriminator, the generator can keep producing those outputs and ignore the other modes of the data.
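One rough way to spot mode collapse is to measure how different the samples in a generated batch are from one another. The sketch below computes the mean pairwise Euclidean distance between flattened samples with NumPy. The helper name `mean_pairwise_distance` and the random stand-in batches are assumptions for illustration; this is a crude diagnostic, not a standard metric.

```python
import numpy as np

def mean_pairwise_distance(samples):
    """Average Euclidean distance over all distinct pairs in a batch.

    samples: array of shape (n, ...), e.g. a batch of generated images.
    Values near zero mean the batch is nearly identical (possible mode collapse).
    """
    flat = np.asarray(samples, dtype=np.float64).reshape(len(samples), -1)
    diffs = flat[:, None, :] - flat[None, :, :]    # (n, n, d) pairwise differences
    dists = np.sqrt(np.sum(diffs ** 2, axis=-1))   # (n, n) distance matrix
    n = flat.shape[0]
    return dists.sum() / (n * (n - 1))             # diagonal is zero; exclude it from the count

rng = np.random.default_rng(0)
# Varied outputs, standing in for a healthy generator.
diverse = rng.uniform(-1.0, 1.0, size=(16, 32, 32, 3))
# One output repeated with tiny noise, standing in for a collapsed generator.
collapsed = np.repeat(diverse[:1], 16, axis=0) + 1e-3 * rng.normal(size=(16, 32, 32, 3))
print(mean_pairwise_distance(diverse) > 10 * mean_pairwise_distance(collapsed))  # True
```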

Many techniques have been proposed to mitigate mode collapse, such as batch normalization, minibatch discrimination, and the Wasserstein GAN (WGAN) objective. These techniques help stabilize training and encourage variety in the generated samples.
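As an example of one of these techniques, the original WGAN replaces the discriminator with a critic f that outputs unbounded scores, trains it to maximize E[f(x)] - E[f(G(z))], and keeps f roughly Lipschitz by clipping its weights after every update. The NumPy sketch below shows only the two losses and the clipping step. The function names are illustrative assumptions, and c = 0.01 is the clipping value used in the original WGAN paper's algorithm; a real implementation would apply these inside a training step like `train_step`.

```python
import numpy as np

def wgan_critic_loss(real_scores, fake_scores):
    # The critic maximizes E[f(x)] - E[f(G(z))]; as a loss to minimize we negate it.
    return np.mean(fake_scores) - np.mean(real_scores)

def wgan_generator_loss(fake_scores):
    # The generator wants high critic scores on its samples, so it minimizes -E[f(G(z))].
    return -np.mean(fake_scores)

def clip_weights(weights, c=0.01):
    # Weight clipping: force every critic parameter into [-c, c] after each update.
    return [np.clip(w, -c, c) for w in weights]

real_scores = np.array([2.0, 4.0])
fake_scores = np.array([1.0, 1.0])
print(wgan_critic_loss(real_scores, fake_scores))  # -2.0
print(wgan_generator_loss(fake_scores))            # -1.0
print(clip_weights([np.array([0.5, -0.003, -2.0])])[0])
```

Later work (WGAN-GP) replaces weight clipping with a gradient penalty on the critic.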


Q3: In GAN training, the discriminator acts as a learned loss function for the generator. Its primary function is to distinguish between real data from the dataset and fake samples created by the generator. During training, the discriminator provides gradients to the generator that indicate how to make the generated samples more realistic. If the discriminator is well trained, these gradients are informative and guide the generator toward producing more realistic data. However, if the discriminator is weak, its feedback carries little information and the generator may produce low-quality samples. Conversely, if the discriminator becomes too strong, it rejects generated samples so confidently that the generator's gradients can vanish.
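The vanishing-gradient point can be made concrete. Write the discriminator's output on a fake sample as D = sigmoid(l) for a logit l. The generator term of the original minimax objective, log(1 - D), has derivative -sigmoid(l) with respect to l, which goes to zero when a strong discriminator confidently rejects the fake (l very negative). The "non-saturating" loss -log D, which is what `generator_loss` computes through binary cross-entropy with target 1, has derivative -(1 - sigmoid(l)) and stays close to -1 in that regime. The NumPy sketch below only evaluates these two closed-form derivatives; the function names are illustrative.

```python
import numpy as np

def sigmoid(l):
    return 1.0 / (1.0 + np.exp(-l))

def grad_saturating(l):
    # d/dl of log(1 - sigmoid(l)), the generator term in the original minimax objective.
    return -sigmoid(l)

def grad_non_saturating(l):
    # d/dl of -log(sigmoid(l)), i.e. binary cross-entropy with target 1 on the fake logit.
    return -(1.0 - sigmoid(l))

logits = np.array([-10.0, -5.0, 0.0])  # -10.0: D is very sure the sample is fake
for l in logits:
    print(f"logit={l:6.1f}  saturating={grad_saturating(l):+.5f}  "
          f"non-saturating={grad_non_saturating(l):+.5f}")
```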


Q4: The Inception Score (IS) and Fréchet Inception Distance (FID) are commonly used metrics for evaluating the quality of GAN-generated samples. IS measures the quality and diversity of generated images using a pretrained Inception network. High-quality images produce confident class predictions, and a diverse set of outputs produces a broad distribution of predicted classes across the whole set. Because IS does not compare generated samples to real data, it can be misleading: it does not check whether the generated images resemble the training dataset at all.
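Formally, IS = exp(E_x[KL(p(y|x) || p(y))]), where p(y|x) is the Inception network's class distribution for a generated image x and p(y) is the average of those distributions over all generated images. The NumPy sketch below computes the score from a matrix of predicted class probabilities. Obtaining those probabilities from a pretrained Inception-v3 model is assumed and not shown, and the toy inputs are made up for illustration. Reference implementations typically also average the score over several splits of the samples, which this sketch omits.

```python
import numpy as np

def inception_score(probs, eps=1e-12):
    """probs: array of shape (n_images, n_classes); each row is p(y|x) and sums to 1."""
    probs = np.asarray(probs, dtype=np.float64)
    p_y = probs.mean(axis=0, keepdims=True)  # marginal class distribution p(y)
    # KL(p(y|x) || p(y)) for each image.
    kl = np.sum(probs * (np.log(probs + eps) - np.log(p_y + eps)), axis=1)
    return float(np.exp(kl.mean()))

n_classes = 10
# Confident and diverse: each image is assigned a different class with certainty.
ideal = np.eye(n_classes)
# Confident but collapsed: every image is assigned the same class.
collapsed = np.tile(np.eye(n_classes)[:1], (n_classes, 1))
print(round(inception_score(ideal), 3))      # 10.0
print(round(inception_score(collapsed), 3))  # 1.0
```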

FID compares the statistics of real and generated images in the feature space of a pretrained Inception network. It models each set of features as a multivariate Gaussian and computes the Fréchet distance between the two Gaussians, which depends on their means and covariances. A lower FID indicates that the generated images are closer to the real data distribution.
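Concretely, FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 (Sigma_r Sigma_g)^(1/2)), where mu and Sigma are the mean and covariance of the feature vectors of real (r) and generated (g) images; the standard metric uses 2048-dimensional Inception-v3 pooling features. The NumPy sketch below assumes the features are already extracted and uses random vectors as stand-ins. It computes the matrix-square-root trace through a symmetric eigendecomposition rather than `scipy.linalg.sqrtm`; that is a design choice for this sketch, not necessarily how any particular library does it.

```python
import numpy as np

def _sqrt_psd(mat):
    # Square root of a symmetric positive semi-definite matrix via eigendecomposition.
    vals, vecs = np.linalg.eigh(mat)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T

def frechet_distance(feats_real, feats_gen):
    """feats_*: arrays of shape (n_samples, n_features)."""
    mu_r, mu_g = feats_real.mean(axis=0), feats_gen.mean(axis=0)
    sigma_r = np.cov(feats_real, rowvar=False)
    sigma_g = np.cov(feats_gen, rowvar=False)
    # Tr((sigma_r @ sigma_g)^(1/2)) equals the sum of square roots of the eigenvalues
    # of the symmetric matrix sigma_r^(1/2) @ sigma_g @ sigma_r^(1/2).
    root_r = _sqrt_psd(sigma_r)
    middle = root_r @ sigma_g @ root_r
    middle = (middle + middle.T) / 2.0  # remove tiny asymmetry from floating-point error
    tr_covmean = np.sum(np.sqrt(np.clip(np.linalg.eigvalsh(middle), 0.0, None)))
    diff = mu_r - mu_g
    return float(diff @ diff + np.trace(sigma_r) + np.trace(sigma_g) - 2.0 * tr_covmean)

rng = np.random.default_rng(0)
real = rng.normal(size=(1000, 4))
print(frechet_distance(real, real))        # ~0.0 (identical statistics)
print(frechet_distance(real, real + 3.0))  # ~36.0 (mean shifted by 3 in each of 4 dims)
```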


In [1]:
import os
import time
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import matplotlib.pyplot as plt

# config
EPOCHS = 50
NOISE_DIM = 128
BATCH_SIZE = 256
BUFFER_SIZE = 50000
SAVE_EVERY = 10
OUTPUT_DIR = "gan_cifar10_outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

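# Load the CIFAR-10 training images (labels unused) and scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output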
(train_images, _), (_, _) = tf.keras.datasets.cifar10.load_data()
train_images = train_images.astype("float32")
train_images = (train_images - 127.5) / 127.5

train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 ━━━━━━━━━━━━━━━━━━━━ 4s 0us/step


In [2]:
#Generator model
def make_generator_model(noise_dim=NOISE_DIM):
    # output: 32x32x3 with tanh
    model = tf.keras.Sequential(name="generator")

    model.add(layers.Input(shape=(noise_dim,)))
    model.add(layers.Dense(4 * 4 * 512, use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Reshape((4, 4, 512)))  #4x4x512

    model.add(layers.Conv2DTranspose(256, (4, 4), strides=(2, 2), padding="same", use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())  #8x8x256

    model.add(layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same", use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())  #16x16x128

    model.add(layers.Conv2DTranspose(128, (3, 3), strides=(1, 1), padding="same", use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())  #16x16x128

    model.add(layers.Conv2DTranspose(64, (4, 4), strides=(2, 2), padding="same", use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())  #32x32x64

    model.add(layers.Conv2DTranspose(64, (3, 3), strides=(1, 1), padding="same", use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())  #32x32x64

    model.add(layers.Conv2DTranspose(3, (3, 3), strides=(1, 1), padding="same", use_bias=False, activation="tanh"))
    return model
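The shape comments in `make_generator_model` can be checked by hand: with `padding="same"` (and no `output_padding`), a Keras `Conv2DTranspose` layer multiplies the spatial size by its stride. The plain-Python sketch below traces that arithmetic through the six transposed convolutions. It does not build the Keras model, and the helper name is made up for illustration; in the notebook, `generator.summary()` reports the shapes of the built model.

```python
def conv_transpose_same(size, stride):
    # Conv2DTranspose with padding="same" and no output_padding: output = input * stride.
    return size * stride

# (filters, stride) of the six Conv2DTranspose layers in make_generator_model.
generator_layers = [(256, 2), (128, 2), (128, 1), (64, 2), (64, 1), (3, 1)]

size, channels = 4, 512  # after Dense(4 * 4 * 512) and Reshape((4, 4, 512))
shapes = [(size, size, channels)]
for filters, stride in generator_layers:
    size, channels = conv_transpose_same(size, stride), filters
    shapes.append((size, size, channels))

for shape in shapes:
    print(shape)
# The last shape printed is (32, 32, 3), matching CIFAR-10 images.
```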

In [3]:
#Discriminator model
def make_discriminator_model():
    model = tf.keras.Sequential(name="discriminator")
    model.add(layers.Input(shape=(32, 32, 3)))

    model.add(layers.Conv2D(64, (4, 4), strides=(2, 2), padding="same"))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (4, 4), strides=(2, 2), padding="same"))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(256, (4, 4), strides=(2, 2), padding="same"))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))
    return model


generator = make_generator_model()
discriminator = make_discriminator_model()
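The same kind of check applies to the discriminator: a `Conv2D` layer with `padding="same"` produces ceil(size / stride) outputs per spatial dimension, so the three stride-2 layers take a 32x32 image down to 4x4 before `Flatten`. The plain-Python sketch below traces these sizes and the number of features fed to the final `Dense(1)`. It is arithmetic only, and the helper name is hypothetical.

```python
import math

def conv_same(size, stride):
    # Conv2D with padding="same": output size = ceil(input size / stride).
    return math.ceil(size / stride)

# (filters, stride) of the three Conv2D layers in make_discriminator_model.
discriminator_layers = [(64, 2), (128, 2), (256, 2)]

size, channels = 32, 3  # CIFAR-10 input
for filters, stride in discriminator_layers:
    size, channels = conv_same(size, stride), filters
    print((size, size, channels))

flat_features = size * size * channels  # what Flatten() passes to Dense(1)
print(flat_features)  # 4096
```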


In [4]:
#losses and optimizers

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_logits, fake_logits):
    real_loss = cross_entropy(tf.ones_like(real_logits), real_logits)
    fake_loss = cross_entropy(tf.zeros_like(fake_logits), fake_logits)
    return real_loss + fake_loss

def generator_loss(fake_logits):
    return cross_entropy(tf.ones_like(fake_logits), fake_logits)

generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

num_examples_to_generate = 16
seed = tf.random.normal([num_examples_to_generate, NOISE_DIM])

# image saving
def save_images(model, epoch, test_input, out_dir=OUTPUT_DIR):
    predictions = model(test_input, training=False)
    images = (predictions + 1.0) / 2.0
    images = tf.clip_by_value(images, 0.0, 1.0).numpy()

    fig = plt.figure(figsize=(4, 4))
    for i in range(images.shape[0]):
        plt.subplot(4, 4, i + 1)
        plt.imshow(images[i])
        plt.axis("off")

    path = os.path.join(out_dir, f"epoch_{epoch:03d}.png")
    plt.tight_layout()
    plt.savefig(path, dpi=150)
    plt.close(fig)
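The preprocessing step and `save_images` are inverses of each other: training images are mapped from [0, 255] to [-1, 1] with (x - 127.5) / 127.5 so they share the range of the generator's `tanh` output, and `save_images` maps generator output back to [0, 1] with (x + 1) / 2 before plotting. The NumPy sketch below checks that round trip on the extreme and middle pixel values; the function names are illustrative and do not appear in the notebook.

```python
import numpy as np

def to_model_range(pixels):
    # Same transform as the CIFAR-10 preprocessing: [0, 255] -> [-1, 1].
    return (np.asarray(pixels, dtype=np.float32) - 127.5) / 127.5

def to_display_range(model_output):
    # Same transform as save_images: [-1, 1] -> [0, 1], clipped for imshow.
    return np.clip((model_output + 1.0) / 2.0, 0.0, 1.0)

pixels = np.array([0, 127.5, 255])
scaled = to_model_range(pixels)
print(scaled)
print(np.allclose(to_display_range(scaled) * 255.0, pixels))  # True
```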

In [5]:
# Training function
@tf.function
def train_step(real_images):
    batch_size = tf.shape(real_images)[0]
    noise = tf.random.normal([batch_size, NOISE_DIM])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake_images = generator(noise, training=True)

        real_logits = discriminator(real_images, training=True)
        fake_logits = discriminator(fake_images, training=True)

        gen_loss = generator_loss(fake_logits)
        disc_loss = discriminator_loss(real_logits, fake_logits)

    gen_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))

    return gen_loss, disc_loss
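Both loss functions called in `train_step` use `BinaryCrossentropy(from_logits=True)`, which is why the discriminator ends in a plain `Dense(1)` with no sigmoid: the loss works on raw logits and applies the sigmoid inside the computation. For a logit l and target y, TensorFlow documents a numerically stable form for `tf.nn.sigmoid_cross_entropy_with_logits`: max(l, 0) - l*y + log(1 + exp(-|l|)). The NumPy sketch below compares that expression with the naive -[y*log(sigmoid(l)) + (1 - y)*log(1 - sigmoid(l))]; the function names are illustrative.

```python
import numpy as np

def bce_naive(y, logit):
    # Direct formula; log(sigmoid(l)) becomes log(0) for very negative logits.
    p = 1.0 / (1.0 + np.exp(-logit))
    return -(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))

def bce_from_logits(y, logit):
    # Stable rewrite documented for tf.nn.sigmoid_cross_entropy_with_logits.
    return np.maximum(logit, 0.0) - logit * y + np.log1p(np.exp(-np.abs(logit)))

logits = np.array([-3.0, 0.0, 2.5])
targets = np.array([1.0, 0.0, 1.0])
print(np.allclose(bce_naive(targets, logits), bce_from_logits(targets, logits)))  # True

# A very confident "fake" logit scored against the generator's target of 1:
print(bce_from_logits(1.0, -800.0))  # 800.0 (the naive formula would return inf here)
```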


In [6]:
#Training loop
def train(dataset, epochs):
    save_images(generator, 0, seed)
    for epoch in range(1, epochs + 1):
        for batch in dataset:
            train_step(batch)

        print(f"Epoch {epoch}/{epochs} completed")

        if epoch % SAVE_EVERY == 0:
            save_images(generator, epoch, seed)
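`train` discards the `(gen_loss, disc_loss)` pair that `train_step` returns, and `time` is imported but never used, so the log below only confirms that each epoch finished. A minimal sketch of a loop that also reports mean losses and epoch time follows. To stay self-contained it takes the step function, the dataset, and an optional callback as arguments; `train_with_logging` and `on_epoch_end` are hypothetical names, not part of the notebook.

```python
import time

def train_with_logging(dataset, epochs, step_fn, on_epoch_end=None):
    """Run step_fn on every batch and print mean losses per epoch.

    step_fn(batch) must return (gen_loss, disc_loss) as values float() accepts,
    e.g. the scalar tensors returned by train_step.
    """
    history = []
    for epoch in range(1, epochs + 1):
        start = time.time()
        gen_total, disc_total, n_batches = 0.0, 0.0, 0
        for batch in dataset:
            gen_loss, disc_loss = step_fn(batch)
            gen_total += float(gen_loss)
            disc_total += float(disc_loss)
            n_batches += 1
        gen_mean = gen_total / max(n_batches, 1)
        disc_mean = disc_total / max(n_batches, 1)
        history.append((gen_mean, disc_mean))
        print(f"Epoch {epoch}/{epochs}  gen_loss={gen_mean:.4f}  "
              f"disc_loss={disc_mean:.4f}  ({time.time() - start:.1f}s)")
        if on_epoch_end is not None:
            on_epoch_end(epoch)
    return history

# Stand-in data and step function so the sketch runs on its own.
fake_batches = [1, 2, 3]
history = train_with_logging(fake_batches, epochs=2, step_fn=lambda b: (b * 1.0, b * 0.5))

# Assumed usage in the notebook (after calling save_images(generator, 0, seed) once):
# history = train_with_logging(train_dataset, EPOCHS, train_step,
#     on_epoch_end=lambda e: save_images(generator, e, seed) if e % SAVE_EVERY == 0 else None)
```

Returning `history` makes it straightforward to plot the generator and discriminator losses after training, which is often the first place signs of an overpowering discriminator show up.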


In [7]:
train(train_dataset, EPOCHS)
print("Done. Check gan_cifar10_outputs for images.")


Epoch 1/50 completed
Epoch 2/50 completed
Epoch 3/50 completed
Epoch 4/50 completed
Epoch 5/50 completed
Epoch 6/50 completed
Epoch 7/50 completed
Epoch 8/50 completed
Epoch 9/50 completed
Epoch 10/50 completed
Epoch 11/50 completed
Epoch 12/50 completed
Epoch 13/50 completed
Epoch 14/50 completed
Epoch 15/50 completed
Epoch 16/50 completed
Epoch 17/50 completed
Epoch 18/50 completed
Epoch 19/50 completed
Epoch 20/50 completed
Epoch 21/50 completed
Epoch 22/50 completed
Epoch 23/50 completed
Epoch 24/50 completed
Epoch 25/50 completed
Epoch 26/50 completed
Epoch 27/50 completed
Epoch 28/50 completed
Epoch 29/50 completed
Epoch 30/50 completed
Epoch 31/50 completed
Epoch 32/50 completed
Epoch 33/50 completed
Epoch 34/50 completed
Epoch 35/50 completed
Epoch 36/50 completed
Epoch 37/50 completed
Epoch 38/50 completed
Epoch 39/50 completed
Epoch 40/50 completed
Epoch 41/50 completed
Epoch 42/50 completed
Epoch 43/50 completed
Epoch 44/50 completed
Epoch 45/50 completed
Epoch 46/50 complet