# Generative Adversarial Network (GAN) on CIFAR-10

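This notebook builds and trains a simple Generative Adversarial Network (GAN) that learns to generate 32×32 color images from the CIFAR-10 training set. Both the generator and the discriminator are fully connected (Dense-layer) networks rather than convolutional ones, which keeps the example short.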

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

## Load and Normalize CIFAR-10 Data

In [None]:
def load_real_images():
    """Load the CIFAR-10 training images scaled to the generator's output range.

    Only the training-split images are used; labels and the test split are
    discarded.

    Returns:
        A float32 array of the training images with pixel values mapped from
        [0, 255] to [-1, 1], matching the ``tanh`` output of the generator.
    """
    (x_train, _), (_, _) = tf.keras.datasets.cifar10.load_data()
    # Cast before the arithmetic: ``uint8 - 127.5`` would silently promote the
    # whole dataset to float64 (twice the memory), while Keras' default float
    # type is float32 anyway.
    x_train = x_train.astype("float32")
    return (x_train - 127.5) / 127.5

## Build Generator

In [None]:
def build_generator():
    """Build the fully connected generator network.

    A 100-dimensional noise vector passes through three hidden blocks of
    Dense (256 -> 512 -> 1024 units) + LeakyReLU(0.2) +
    BatchNormalization(momentum=0.8). A final ``tanh`` Dense layer is then
    reshaped to a 32x32x3 image, so pixel values lie in [-1, 1].

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    stack = []
    for position, width in enumerate((256, 512, 1024)):
        # Only the very first layer declares the shape of the noise input.
        shape_kwargs = {'input_shape': (100,)} if position == 0 else {}
        stack.extend([
            layers.Dense(width, **shape_kwargs),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(momentum=0.8),
        ])
    stack.append(layers.Dense(32 * 32 * 3, activation='tanh'))
    stack.append(layers.Reshape((32, 32, 3)))
    return tf.keras.Sequential(stack)

## Build Discriminator

In [None]:
def build_discriminator():
    """Build the fully connected discriminator network.

    A 32x32x3 image is flattened and passed through Dense(512) and Dense(256)
    layers, each followed by LeakyReLU(0.2). A single sigmoid unit then
    scores the input; the training loop labels real images 1 and generated
    images 0.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    stack = [layers.Flatten(input_shape=(32, 32, 3))]
    for units in (512, 256):
        stack += [layers.Dense(units), layers.LeakyReLU(alpha=0.2)]
    stack.append(layers.Dense(1, activation='sigmoid'))
    return tf.keras.Sequential(stack)

## Compile GAN

In [None]:
def compile_gan(generator, discriminator):
    """Compile the discriminator and build the combined generator -> discriminator model.

    The discriminator is first compiled on its own (binary cross-entropy,
    Adam, accuracy metric). It is then marked non-trainable and stacked
    behind the generator. The intent is that training the returned model
    updates only the generator's weights.

    NOTE: ``discriminator.trainable`` is left as ``False`` on return. Whether
    ``discriminator.train_on_batch`` still updates its weights afterwards
    depends on the Keras version: tf.keras 2.x honours the trainable state
    captured at compile time, while Keras 3 uses the current flag. Callers
    should set the flag explicitly around each update.

    Returns:
        The compiled combined model. It takes (n, 100) noise and returns the
        discriminator's sigmoid score for the generated images.
    """
    discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    # Freeze D only *after* compiling it on its own, and *before* compiling the stack.
    discriminator.trainable = False
    noise_in = layers.Input(shape=(100,))
    validity = discriminator(generator(noise_in))
    combined = tf.keras.models.Model(inputs=noise_in, outputs=validity)
    combined.compile(loss='binary_crossentropy', optimizer='adam')
    return combined

## Training Loop

In [None]:
def train_gan(generator, discriminator, gan, epochs=1000, batch_size=64, save_interval=200):
    """Train the GAN with alternating discriminator / generator updates.

    Each "epoch" here is a single training step on one randomly sampled batch,
    not a full pass over the dataset.

    Args:
        generator: Model mapping (n, 100) noise vectors to images.
        discriminator: Compiled model scoring images (real -> 1, fake -> 0).
        gan: Compiled generator -> discriminator stack, as built by ``compile_gan``.
        epochs: Number of training steps to run.
        batch_size: Noise vectors per generator step. Each discriminator step
            uses ``batch_size // 2`` real and ``batch_size // 2`` generated images.
        save_interval: Print losses and save a sample grid every this many steps.
    """
    x_train = load_real_images()
    half_batch = batch_size // 2
    # Label arrays are constant; build them once. All use the same (n, 1) shape
    # as the discriminator's output.
    real_labels = np.ones((half_batch, 1))
    fake_labels = np.zeros((half_batch, 1))
    valid_y = np.ones((batch_size, 1))
    for epoch in range(epochs):
        # --- Discriminator step: real images labelled 1, generated images labelled 0.
        idx = np.random.randint(0, x_train.shape[0], half_batch)
        real_images = x_train[idx]
        noise = np.random.normal(0, 1, (half_batch, 100))
        # verbose=0: otherwise predict prints a progress bar on every step.
        fake_images = generator.predict(noise, verbose=0)
        # compile_gan leaves the discriminator frozen. Keras 3 reads the current
        # flag at train time (tf.keras 2.x restored the compile-time state), so
        # unfreeze explicitly or D never learns.
        discriminator.trainable = True
        d_loss_real = discriminator.train_on_batch(real_images, real_labels)
        d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        # --- Generator step: freeze D and push its score on fakes towards 1.
        discriminator.trainable = False
        noise = np.random.normal(0, 1, (batch_size, 100))
        g_loss = gan.train_on_batch(noise, valid_y)
        if epoch % save_interval == 0:
            print(f"Epoch {epoch} / {epochs} [D loss: {d_loss[0]}] [G loss: {g_loss}]")
            save_images(generator, epoch)

## Save Generated Images

In [None]:
def save_images(generator, epoch, examples=10, dim=(1, 10), figsize=(10, 1)):
    """Generate sample images and save them as one PNG grid.

    The figure is written to ``gan_images/gan_image_{epoch}.png``. The
    ``gan_images`` directory is created if it does not exist.

    Args:
        generator: Model mapping (n, 100) noise vectors to images with values
            in [-1, 1] (its final activation is ``tanh``).
        epoch: Training step number, used in the output file name.
        examples: Number of images to generate and plot.
        dim: (rows, cols) of the subplot grid; rows * cols must be at least
            ``examples``.
        figsize: Matplotlib figure size in inches.
    """
    noise = np.random.normal(0, 1, (examples, 100))
    generated_images = generator.predict(noise, verbose=0)
    # Map tanh output [-1, 1] back to [0, 1]. Clip away float rounding so
    # imshow does not warn about out-of-range RGB data.
    generated_images = np.clip(0.5 * generated_images + 0.5, 0.0, 1.0)
    fig = plt.figure(figsize=figsize)
    try:
        for i in range(examples):
            plt.subplot(dim[0], dim[1], i + 1)
            plt.imshow(generated_images[i])
            plt.axis('off')
        plt.tight_layout()
        # exist_ok replaces the racy exists()-then-makedirs() check.
        os.makedirs("gan_images", exist_ok=True)
        plt.savefig(f"gan_images/gan_image_{epoch}.png")
    finally:
        # Release the figure even if plotting or saving fails, so repeated
        # calls during training do not accumulate open figures.
        plt.close(fig)

## Build and Train the GAN

In [None]:
# Build both networks, wire them into the combined model, then train with
# train_gan's default schedule.
generator, discriminator = build_generator(), build_discriminator()
gan = compile_gan(generator, discriminator)
train_gan(generator, discriminator, gan)